From 09c533353734c3c56aa34f0e63a032ba3c8cacdb Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 18 Oct 2023 14:36:04 +0200 Subject: [PATCH 01/12] Initial matrix multiplication support. Still some bugs to iron out but works --- Cargo.toml | 1 + examples/mps/main.rs | 3 +- src/lib.rs | 2 + src/mps.rs | 846 +++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 819 insertions(+), 33 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fe80094..dac664f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ block = "0.1.6" foreign-types = "0.5" dispatch = { version = "0.2", optional = true } paste = "1" +half = "2.3.1" [dependencies.objc] version = "0.2.4" diff --git a/examples/mps/main.rs b/examples/mps/main.rs index cc01b7a..85f80f5 100644 --- a/examples/mps/main.rs +++ b/examples/mps/main.rs @@ -1,3 +1,4 @@ +use metal::mps::*; use metal::*; use std::ffi::c_void; use std::mem; @@ -67,7 +68,7 @@ fn main() { acceleration_structure.set_vertex_buffer(Some(&vertex_buffer)); acceleration_structure.set_vertex_stride(vertex_stride as u64); acceleration_structure.set_index_buffer(Some(&index_buffer)); - acceleration_structure.set_index_type(mps::MPSDataType::UInt32); + acceleration_structure.set_index_type(mps::UInt32); acceleration_structure.set_triangle_count(1); acceleration_structure.set_usage(mps::MPSAccelerationStructureUsage::None); acceleration_structure.rebuild(); diff --git a/src/lib.rs b/src/lib.rs index b79acf6..c4962eb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,8 @@ pub extern crate foreign_types; #[macro_use] pub extern crate paste; +pub extern crate half; + use std::{ borrow::{Borrow, ToOwned}, marker::PhantomData, diff --git a/src/mps.rs b/src/mps.rs index edd4936..8b0a3b9 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -6,10 +6,16 @@ // copied, modified, or distributed except according to those terms. use super::*; - +use half::{bf16, f16}; +use objc::rc::autoreleasepool; use objc::runtime::{BOOL, YES}; +use std::fmt::Debug; +use std::hash::Hash; -#[cfg_attr(feature = "link", link(name = "MetalPerformanceShaders", kind = "framework"))] +#[cfg_attr( + feature = "link", + link(name = "MetalPerformanceShaders", kind = "framework") +)] extern "C" { fn MPSSupportsMTLDevice(device: *const std::ffi::c_void) -> BOOL; } @@ -129,33 +135,6 @@ bitflags! { } } -/// A common bit for all floating point data types. -const MPSDataTypeFloatBit: isize = 0x10000000; -const MPSDataTypeSignedBit: isize = 0x20000000; -const MPSDataTypeNormalizedBit: isize = 0x40000000; - -/// See -pub enum MPSDataType { - Invalid = 0, - - Float32 = MPSDataTypeFloatBit | 32, - Float16 = MPSDataTypeFloatBit | 16, - - // Signed integers. - Int8 = MPSDataTypeSignedBit | 8, - Int16 = MPSDataTypeSignedBit | 16, - Int32 = MPSDataTypeSignedBit | 32, - - // Unsigned integers. Range: [0, UTYPE_MAX] - UInt8 = 8, - UInt16 = 16, - UInt32 = 32, - - // Unsigned normalized. Range: [0, 1.0] - Unorm1 = MPSDataTypeNormalizedBit | 1, - Unorm8 = MPSDataTypeNormalizedBit | 8, -} - /// A kernel that performs intersection tests between rays and geometry. /// /// See @@ -202,7 +181,7 @@ impl RayIntersectorRef { unsafe { msg_send![self, setRayDataType: ty] } } - pub fn set_ray_index_data_type(&self, ty: MPSDataType) { + pub fn set_ray_index_data_type(&self, ty: T) { unsafe { msg_send![self, setRayIndexDataType: ty] } } @@ -345,8 +324,8 @@ impl PolygonAccelerationStructureRef { unsafe { msg_send![self, setIndexBufferOffset: offset] } } - pub fn set_index_type(&self, data_type: MPSDataType) { - unsafe { msg_send![self, setIndexType: data_type] } + pub fn set_index_type(&self, _data_type: T) { + unsafe { msg_send![self, setIndexType: T::ENCODING] } } pub fn set_mask_buffer(&self, buffer: Option<&BufferRef>) { @@ -570,3 +549,806 @@ pub struct MPSIntersectionDistancePrimitiveIndexCoordinates { /// if the intersection type is `MPSIntersectionTypeAny`. pub coordinates: [f32; 2], } + +/// A value to specify a type of data. +/// +/// See . +pub trait MPSDataType: Clone + Copy + PartialEq + Eq + Debug + Hash { + type Type: Default + Clone + Copy + PartialEq + Debug + Sized; + const ENCODING: u32; + + /// See . + const SIZE: u32 = ((Self::ENCODING & 0xFFFF) >> 3); +} + +/// A common bit for all floating point data types. Zero for integer types +const MPS_FLOATBIT_ENCODING: u32 = 0x10000000; +/// A common bit for all complex point data types. Zero for integer types +const MPS_COMPLEXBIT_ENCODING: u32 = MPS_FLOATBIT_ENCODING | 0x01000000; +/// A common bit for all signed data types +const MPS_SIGNEDBIT_ENCODING: u32 = 0x20000000; +/// A common bit for all alternate encoding data types +const MPS_ALTERNATE_ENCODING: u32 = 0x80000000; +/// A common bit for all normalized data types. +/// If set, the value of the shall be interpreted as value / UNORM_TYPE_MAX +/// Normalized values have range [0, 1.0] if unsigned and [-1,1] if signed. +/// SNORM_TYPE_MIN is interpreted as SNORM_TYPE_MIN+1 per standard Metal rules. +const MPS_NORMALIZEDBIT_ENCODING: u32 = 0x40000000; +macro_rules! mps_datatype { + ($dt:ident, $dt_ty:ty, $encoding:expr, $comment:expr) => { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + #[doc=$comment] + pub struct $dt; + + impl MPSDataType for $dt { + type Type = $dt_ty; + const ENCODING: u32 = $encoding; + } + }; +} +mps_datatype!(Invalid, (), 0, "An invalid data type."); + +mps_datatype!( + Float32, + f32, + MPS_FLOATBIT_ENCODING | 32, + "32-bit floating point (single-precision)." +); +mps_datatype!( + Float16, + f16, + MPS_FLOATBIT_ENCODING | 16, + "16-bit floating point (half-precision). (IEEE-754-2008 float16 exchange format)" +); + +mps_datatype!( + ComplexFloat32, + (f32, f32), + MPS_COMPLEXBIT_ENCODING | 64, + "Complex number composed of two 32-bit floating point numbers (single-precision)." +); +mps_datatype!(ComplexFloat16, (f16, f16), MPS_COMPLEXBIT_ENCODING | 32, "Complex number composed of two 16-bit floating point numbers (half-precision). (IEEE-754-2008 float16 exchange format)"); + +mps_datatype!( + Int8, + i8, + MPS_SIGNEDBIT_ENCODING | 8, + "Signed 8-bit integer." +); +mps_datatype!( + Int16, + i16, + MPS_SIGNEDBIT_ENCODING | 16, + "Signed 16-bit integer." +); +mps_datatype!( + Int32, + i32, + MPS_SIGNEDBIT_ENCODING | 32, + "Signed 32-bit integer." +); +mps_datatype!( + Int64, + i64, + MPS_SIGNEDBIT_ENCODING | 64, + "Signed 64-bit integer." +); + +mps_datatype!(UInt8, u8, 8, "Unsigned 8-bit integer. Not normalized"); +mps_datatype!(UInt16, u16, 16, "Unsigned 16-bit integer. Not normalized"); +mps_datatype!(UInt32, u32, 32, "Unsigned 32-bit integer. Not normalized"); +mps_datatype!(UInt64, u64, 64, "Unsigned 64-bit integer. Not normalized"); + +mps_datatype!( + Bool, + bool, + MPS_ALTERNATE_ENCODING | 8, + "Boolean as 8-bit integer. Not normalized." +); +mps_datatype!( + BF16, + bf16, + MPS_ALTERNATE_ENCODING | MPS_FLOATBIT_ENCODING | 16, + "Boolean as 8-bit integer. Not normalized." +); +mps_datatype!( + UNorm1, + bool, + MPS_NORMALIZEDBIT_ENCODING | 1, + "Unsigned 1-bit normalized value." +); +mps_datatype!( + UNorm8, + u8, + MPS_NORMALIZEDBIT_ENCODING | 8, + "Unsigned 8-bit normalized value." +); + +/// See +pub enum MPSMatrixDescriptor {} + +foreign_obj_type! { + type CType = MPSMatrixDescriptor; + pub struct MatrixDescriptorObject; + type ParentType = NsObject; +} + +impl MatrixDescriptorObject { + fn init_single( + rows: NSUInteger, + columns: NSUInteger, + row_bytes: NSUInteger, + data_type: u32, + ) -> Self { + unsafe { + msg_send![ + class!(MPSMatrixDescriptor), + matrixDescriptorWithRows : rows + columns : columns + rowBytes : row_bytes + dataType : data_type + ] + } + } + + fn init_multiple( + rows: NSUInteger, + columns: NSUInteger, + matrices: NSUInteger, + matrix_bytes: NSUInteger, + row_bytes: NSUInteger, + data_type: u32, + ) -> Self { + unsafe { + msg_send![ + class!(MPSMatrixDescriptor), + matrixDescriptorWithRows : rows + columns : columns + matrices : matrices + rowBytes : row_bytes + matrixBytes : matrix_bytes + dataType : data_type + ] + } + } + + fn row_bytes_for_columns(columns: NSUInteger, data_type: u32) -> NSUInteger { + unsafe { + msg_send![ + class!(MPSMatrixDescriptor), + rowBytesForColumns : columns + dataType : data_type + ] + } + } +} + +#[derive(Debug)] +pub struct MatrixDescriptor { + pub object: MatrixDescriptorObject, + pub rows: NSUInteger, + pub columns: NSUInteger, + pub matrices: NSUInteger, + pub row_bytes: NSUInteger, + pub matrix_bytes: NSUInteger, + pub _marker: PhantomData, +} + +impl MatrixDescriptor { + pub fn single(rows: NSUInteger, columns: NSUInteger) -> Self { + // The number of bytes between starting elements of consecutive rows. + let row_bytes = Self::row_bytes_for_columns(columns); + + let object = MatrixDescriptorObject::init_single(rows, columns, row_bytes, T::ENCODING); + MatrixDescriptor { + object, + rows, + columns, + row_bytes, + matrices: 1, + matrix_bytes: row_bytes, + _marker: PhantomData, + } + } + + pub fn multiple(rows: NSUInteger, columns: NSUInteger, matrices: NSUInteger) -> Self { + // The number of bytes between starting elements of consecutive rows. + let row_bytes = Self::row_bytes_for_columns(columns); + // The number of bytes between starting elements of consecutive matrices. + let matrix_bytes = row_bytes * rows; + + let object = MatrixDescriptorObject::init_multiple( + rows, + columns, + matrices, + matrix_bytes, + row_bytes, + T::ENCODING, + ); + MatrixDescriptor { + object, + rows, + columns, + matrices, + row_bytes, + matrix_bytes, + _marker: PhantomData, + } + } + + pub fn row_bytes_for_columns(columns: NSUInteger) -> NSUInteger { + MatrixDescriptorObject::row_bytes_for_columns(columns, T::ENCODING) + } + + pub fn required_buffer_size(&self) -> NSUInteger { + self.matrices * self.matrix_bytes + self.rows * self.row_bytes + //+ T::SIZE as NSUInteger * self.columns + } +} + +/// See +pub enum MPSMatrix {} + +foreign_obj_type! { + type CType = MPSMatrix; + pub struct MatrixObject; + type ParentType = NsObject; +} + +impl MatrixObject { + fn init_with_device_descriptor( + device: &DeviceRef, + descriptor: &MatrixDescriptorObjectRef, + ) -> Option { + unsafe { + let matrix: MatrixObject = msg_send![class!(MPSMatrix), alloc]; + let ptr: *mut Object = msg_send![ + matrix.as_ref(), + initWithDevice : device + descriptor : descriptor + ]; + if ptr.is_null() { + None + } else { + Some(matrix) + } + } + } + + fn init_with_buffer_descriptor( + buffer: &BufferRef, + descriptor: &MatrixDescriptorObjectRef, + ) -> Option { + // assert!(buffer.length() >= descriptor.rowBytes() * descriptor.rows()); + // assert_eq!(buffer.length() % descriptor.rowBytes(), 0); + // assert_eq!(buffer.device(), descriptor.device()); + unsafe { + let matrix: MatrixObject = msg_send![class!(MPSMatrix), alloc]; + let ptr: *mut Object = msg_send![ + matrix.as_ref(), + initWithBuffer : buffer + descriptor: descriptor + ]; + if ptr.is_null() { + None + } else { + Some(matrix) + } + } + } +} + +impl MatrixObjectRef { + pub fn device(&self) -> &DeviceRef { + unsafe { msg_send![self, device] } + } + + pub fn rows(&self) -> u64 { + unsafe { msg_send![self, rows] } + } + + pub fn columns(&self) -> u64 { + unsafe { msg_send![self, columns] } + } + + pub fn row_bytes(&self) -> u64 { + unsafe { msg_send![self, rowBytes] } + } + + pub fn data_type(&self) -> u32 { + unsafe { msg_send![self, dataType] } + } + + pub fn data(&self) -> *mut std::ffi::c_void { + unsafe { msg_send![self, data] } + } + + pub fn resource_size(&self) -> u64 { + unsafe { msg_send![self, resourceSize] } + } +} + +#[derive(Debug)] +pub struct Matrix { + pub object: MatrixObject, + pub entries: Vec, // row-major order + pub rows: u64, + pub columns: u64, +} + +impl Matrix { + pub fn init(device: &DeviceRef, descriptor: &MatrixDescriptor) -> Self { + let object = MatrixObject::init_with_device_descriptor(device, &descriptor.object).unwrap(); + let entries = vec![T::Type::default(); (&descriptor.rows * &descriptor.columns) as usize]; + Matrix { + object, + entries, + rows: descriptor.rows, + columns: descriptor.columns, + } + } + + pub fn init_with_buffer(buffer: &BufferRef, descriptor: &MatrixDescriptor) -> Self { + let object = MatrixObject::init_with_buffer_descriptor(buffer, &descriptor.object).unwrap(); + let entries = vec![T::Type::default(); (&descriptor.rows * &descriptor.columns) as usize]; + Matrix { + object, + entries, + rows: descriptor.rows, + columns: descriptor.columns, + } + } +} + +/// A kernel for matrix multiplication. +/// +/// Computes the following operation: +/// +/// `C = alpha * op(A) * op(B) + beta * C` +/// +/// Where A, B, and C are matrices represented by MPSMatrix objects, and alpha and beta are scalar values of the same data type as the values of C. A and B may each have an optional transposition operation applied. +/// +/// Matrices A, B, and C are also referred to as the left input matrix, the right input matrix, and the result matrix respectively. +/// +/// See . +pub enum MPSMatrixMultiplication {} + +foreign_obj_type! { + type CType = MPSMatrixMultiplication; + pub struct MatrixMultiplicationKernel; + type ParentType = Kernel; +} +impl MatrixMultiplicationKernel { + pub fn from_device(device: &DeviceRef) -> Option { + unsafe { + let kernel: MatrixMultiplicationKernel = + msg_send![class!(MPSMatrixMultiplication), alloc]; + let ptr: *mut Object = msg_send![kernel.as_ref(), initWithDevice: device]; + if ptr.is_null() { + None + } else { + Some(kernel) + } + } + } + + pub fn init( + device: &DeviceRef, + transpose_left: bool, + transpose_right: bool, + result_rows: NSUInteger, + result_columns: NSUInteger, + interior_columns: NSUInteger, + alpha: f32, + beta: f32, + ) -> Option { + unsafe { + let kernel: MatrixMultiplicationKernel = + msg_send![class!(MPSMatrixMultiplication), alloc]; + let ptr: *mut Object = msg_send![ + kernel.as_ref(), + initWithDevice : device + transposeLeft : transpose_left + transposeRight : transpose_right + resultRows : result_rows + resultColumns : result_columns + interiorColumns : interior_columns + alpha : alpha + beta : beta + ]; + if ptr.is_null() { + None + } else { + Some(kernel) + } + } + } + + fn init_simple( + device: &DeviceRef, + result_rows: NSUInteger, + result_columns: NSUInteger, + interior_columns: NSUInteger, + ) -> Option { + unsafe { + let kernel: MatrixMultiplicationKernel = + msg_send![class!(MPSMatrixMultiplication), alloc]; + let ptr: *mut Object = msg_send![ + kernel.as_ref(), + initWithDevice : device + resultRows : result_rows + resultColumns : result_columns + interiorColumns : interior_columns + ]; + if ptr.is_null() { + None + } else { + Some(kernel) + } + } + } +} + +#[derive(Debug)] +struct MatrixMultiplication { + kernel: MatrixMultiplicationKernel, + transpose_left: bool, + transpose_right: bool, + result_rows: NSUInteger, + result_columns: NSUInteger, + interior_columns: NSUInteger, + alpha: f32, + beta: f32, +} + +/// Helper trait used indicates that a type constraint is valid. +trait Valid {} + +/// Helper struct used to indicate a valid matrix multiplication input type. +struct MatMulInput { + _marker: PhantomData, +} + +/// Input data type must be one of MPSDataTypeFloat32, MPSDataTypeFloat16, MPSDataTypeInt8, +/// or MPSDataTypeInt16 +impl Valid for MatMulInput {} +impl Valid for MatMulInput {} +impl Valid for MatMulInput {} +impl Valid for MatMulInput {} + +/// Helper struct used to indicate a valid matrix multiplication result type. +struct MatMulResult { + _marker: PhantomData, +} + +/// Only MPSDataTypeFloat16 and MPSDataTypeFloat32 are supported for the result matrix. +impl Valid for MatMulResult {} +impl Valid for MatMulResult {} + +/// Helper struct used to indicate valid matrix multiplication types. +struct MatMulSpecification +where + Left: MPSDataType, + MatMulInput: Valid, + Right: MPSDataType, + MatMulInput: Valid, + Result: MPSDataType, + MatMulResult: Valid, +{ + _marker: PhantomData<(Left, Right, Result)>, +} + +/// Mixed input matrix multiplication is only for +impl Valid for MatMulSpecification {} + +/// All valid input types can produce a MPSDataTypeFloat32 result. +impl Valid for MatMulSpecification +where + Input: MPSDataType, + MatMulInput: Valid, +{ +} + +/// These input types can produce a MPSDataTypeFloat16 result. +impl Valid for MatMulSpecification {} +impl Valid for MatMulSpecification {} +impl Valid for MatMulSpecification {} + +impl MatrixMultiplication { + pub fn init( + device: &DeviceRef, + transpose_left: bool, + transpose_right: bool, + result_rows: NSUInteger, + result_columns: NSUInteger, + interior_columns: NSUInteger, + alpha: f32, + beta: f32, + ) -> Option { + assert!(result_rows > 0); + assert!(result_columns > 0); + assert!(interior_columns > 0); + if let Some(kernel) = MatrixMultiplicationKernel::init( + device, + transpose_left, + transpose_right, + result_rows, + result_columns, + interior_columns, + alpha, + beta, + ) { + return Some(MatrixMultiplication { + kernel, + transpose_left, + transpose_right, + result_rows, + result_columns, + interior_columns, + alpha, + beta, + }); + } + None + } + + pub fn init_simple( + device: &DeviceRef, + result_rows: NSUInteger, + result_columns: NSUInteger, + interior_columns: NSUInteger, + ) -> Option { + assert!(result_rows > 0); + assert!(result_columns > 0); + assert!(interior_columns > 0); + if let Some(kernel) = MatrixMultiplicationKernel::init_simple( + device, + result_rows, + result_columns, + interior_columns, + ) { + return Some(MatrixMultiplication { + kernel, + transpose_left: false, + transpose_right: false, + result_rows, + result_columns, + interior_columns, + alpha: 1.0, + beta: 0.0, + }); + } + None + } + /// Encode the kernel to the given command buffer. + /// * `command_buffer` - The command buffer to encode the kernel to. + /// * `left_matrix` - The left matrix to multiply. + /// * `right_matrix` - The right matrix to multiply. + /// * `result_matrix` - The matrix to store the result in. + pub fn encode_to_command_buffer( + &self, + command_buffer: &CommandBufferRef, + left_matrix: &Matrix, + right_matrix: &Matrix, + result_matrix: &Matrix, + ) where + T: MPSDataType, + U: MPSDataType, + V: MPSDataType, + MatMulInput: Valid, + MatMulInput: Valid, + MatMulResult: Valid, + MatMulSpecification: Valid, + { + // Certain constraints apply to the sizes of the matrices depending on the transposition + // operations and sizes requested at initialization time as well as the origins at the time + // this routine is called: + // + // The left input matrix must be large enough to hold an array of size resultRows x interiorColumns + // elements beginning at leftMatrixOrigin. + assert!(left_matrix.rows * left_matrix.columns >= self.result_rows * self.interior_columns); + // The right input matrix must be large enough to hold an array of size + // interiorColumns x resultColumns elements beginning at rightMatrixOrigin. + assert!( + right_matrix.rows * right_matrix.columns >= self.interior_columns * self.result_columns + ); + // The result matrix must be large enough to hold an array of size resultRows x resultColumns + // elements beginning at resultMatrixOrigin. + assert!( + result_matrix.rows * result_matrix.columns >= self.result_rows * self.result_columns + ); + + // Each matrix within the range specified by batchStart and batchSize, which also specifies + // a valid set of matrices within leftMatrix, rightMatrix, and resultMatrix, will + // be processed. + + self.kernel.encode_to_command_buffer( + command_buffer, + &left_matrix.object, + &right_matrix.object, + &result_matrix.object, + ); + } +} + +impl MatrixMultiplicationKernelRef { + pub fn encode_to_command_buffer( + &self, + command_buffer: &CommandBufferRef, + left_matrix: &MatrixObjectRef, + right_matrix: &MatrixObjectRef, + result_matrix: &MatrixObjectRef, + ) { + unsafe { + let _: () = msg_send!( + *self, + encodeToCommandBuffer : command_buffer + leftMatrix : left_matrix + rightMatrix : right_matrix + resultMatrix : result_matrix + ); + } + } +} + +#[derive(Debug)] +pub struct MatrixBuffer { + buffer: Buffer, + rows: NSUInteger, + columns: NSUInteger, + _marker: PhantomData, +} + +impl MatrixBuffer { + pub fn new( + device: &DeviceRef, + descriptor: &MatrixDescriptor, + options: MTLResourceOptions, + ) -> Self { + let buffer = device.new_buffer(descriptor.required_buffer_size(), options); + MatrixBuffer { + buffer, + rows: descriptor.rows, + columns: descriptor.columns, + _marker: PhantomData, + } + } + + pub fn new_with_data( + device: &DeviceRef, + entries: &Vec, + descriptor: &MatrixDescriptor, + options: MTLResourceOptions, + ) -> Self { + let buffer = device.new_buffer_with_data( + entries.as_ptr().cast(), + descriptor.required_buffer_size(), + options, + ); + MatrixBuffer { + buffer, + rows: descriptor.rows, + columns: descriptor.columns, + _marker: PhantomData, + } + } + + pub fn contents(&self) -> Vec { + let contents = self.buffer.contents() as *const T::Type; + let sl: &[T::Type] = + unsafe { std::slice::from_raw_parts(contents, (self.rows * self.columns) as usize) }; + sl.to_vec() + } +} + +fn matmul( + transpose_left: bool, + transpose_right: bool, + left_descriptor: MatrixDescriptor, + left_entries: &Vec, + right_descriptor: MatrixDescriptor, + right_entries: &Vec, + alpha: f32, + beta: f32, +) -> Vec +where + T: MPSDataType, + U: MPSDataType, + V: MPSDataType, + MatMulInput: Valid, + MatMulInput: Valid, + MatMulResult: Valid, + MatMulSpecification: Valid, +{ + // For matrix multiplication, the number of columns in the first matrix must be equal to + // the number of rows in the second matrix. + // The result matrix has the number of rows of the first and the number of columns of the + // second matrix. + let result_descriptor = + MatrixDescriptor::::single(left_descriptor.rows, right_descriptor.columns); + + let device = Device::system_default().expect("No device found"); + let matrix_multiplication = MatrixMultiplication::init_simple( + &device, + left_descriptor.rows, + right_descriptor.columns, + left_descriptor.columns, + ) + .unwrap(); + + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + let left_buffer = MatrixBuffer::new_with_data( + &device, + left_entries, + &left_descriptor, + MTLResourceOptions::StorageModeShared, + ); + let right_buffer = MatrixBuffer::new_with_data( + &device, + right_entries, + &right_descriptor, + MTLResourceOptions::StorageModeShared, + ); + let result_buffer = MatrixBuffer::new( + &device, + &result_descriptor, + MTLResourceOptions::StorageModeShared, + ); + + let left_matrix = Matrix::init_with_buffer(&left_buffer.buffer, &left_descriptor); + + let right_matrix = Matrix::init_with_buffer(&right_buffer.buffer, &right_descriptor); + + let result_matrix = Matrix::init_with_buffer(&result_buffer.buffer, &result_descriptor); + + matrix_multiplication.encode_to_command_buffer( + &command_buffer, + &left_matrix, + &right_matrix, + &result_matrix, + ); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let result = result_buffer.contents(); + result +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::Rng; + + fn random_matrix(rows: usize, columns: usize) -> (Vec, NSUInteger, NSUInteger) { + let mut rng = rand::thread_rng(); + let mut entries = vec![0.0; rows * columns]; + for i in 0..rows { + for j in 0..columns { + entries[i * columns + j] = rng.gen(); + } + } + (entries, rows as NSUInteger, columns as NSUInteger) + } + + #[test] + fn test_matrix_multiplication() { + let (left_entries, l_rows, l_columns) = random_matrix(1024, 1024); + let (right_entries, r_rows, r_columns) = random_matrix(1024, 1024); + autoreleasepool(|| { + let result = matmul( + false, + false, + MatrixDescriptor::::single(l_rows, l_columns), + &left_entries, + MatrixDescriptor::::single(r_rows, r_columns), + &right_entries, + 1.0, + 0.0, + ); + + println!("{:?}", result.len()); + }); + } +} From 4f4df066a1c425e9f507b6ae5704aa5af18d478e Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Fri, 20 Oct 2023 11:01:31 +0200 Subject: [PATCH 02/12] Remove over the top abstractions. Changed test to example --- Cargo.toml | 8 +- examples/mps/matrix-multiplication/main.rs | 37 + examples/mps/{ => ray-intersection}/main.rs | 5 +- .../mps/{ => ray-intersection}/shaders.metal | 0 .../{ => ray-intersection}/shaders.metallib | Bin src/mps.rs | 740 ++++++++---------- 6 files changed, 378 insertions(+), 412 deletions(-) create mode 100644 examples/mps/matrix-multiplication/main.rs rename examples/mps/{ => ray-intersection}/main.rs (97%) rename examples/mps/{ => ray-intersection}/shaders.metal (100%) rename examples/mps/{ => ray-intersection}/shaders.metallib (100%) diff --git a/Cargo.toml b/Cargo.toml index dac664f..29a6037 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,7 +77,13 @@ name = "compute" path = "examples/compute/main.rs" [[example]] -name = "mps" +name = "mps-matrix-multiplication" +path = "examples/mps/matrix-multiplication/main.rs" +required-features = ["mps"] + +[[example]] +name = "mps-ray-intersection" +path = "examples/mps/ray-intersection/main.rs" required-features = ["mps"] [[example]] diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs new file mode 100644 index 0000000..ed54c46 --- /dev/null +++ b/examples/mps/matrix-multiplication/main.rs @@ -0,0 +1,37 @@ +use metal::mps::*; +use metal::*; + +fn generate_matrix() -> Matrix +where + T: MPSDataType, + MatMulInput: Valid, +{ + Matrix { + entries: (1..=ROWS * COLS).map(|i| T::from_f64(i as f64)).collect(), + rows: ROWS as NSUInteger, + columns: COLS as NSUInteger, + } +} + +fn main() { + type A = Float32; + type B = Float32; + type C = Float32; + const M: usize = 1; + const N: usize = 1; + const K: usize = 5; + + let transpose_left = false; + let transpose_right = false; + let alpha = 1.0; + let beta = 0.0; + + let left = generate_matrix::(); + let right = generate_matrix::(); + + println!("{left:?}"); + println!("{right:?}"); + + let result = matrix_multiplication(transpose_left, transpose_right, &left, &right, alpha, beta); + println!("{result:?}"); +} diff --git a/examples/mps/main.rs b/examples/mps/ray-intersection/main.rs similarity index 97% rename from examples/mps/main.rs rename to examples/mps/ray-intersection/main.rs index 85f80f5..ed79411 100644 --- a/examples/mps/main.rs +++ b/examples/mps/ray-intersection/main.rs @@ -1,4 +1,3 @@ -use metal::mps::*; use metal::*; use std::ffi::c_void; use std::mem; @@ -15,8 +14,8 @@ type Intersection = mps::MPSIntersectionDistancePrimitiveIndexCoordinates; fn main() { let device = Device::system_default().expect("No device found"); - let library_path = - std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("examples/mps/shaders.metallib"); + let library_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("examples/mps/ray-intersection/shaders.metallib"); let library = device .new_library_with_file(library_path) .expect("Failed to load shader library"); diff --git a/examples/mps/shaders.metal b/examples/mps/ray-intersection/shaders.metal similarity index 100% rename from examples/mps/shaders.metal rename to examples/mps/ray-intersection/shaders.metal diff --git a/examples/mps/shaders.metallib b/examples/mps/ray-intersection/shaders.metallib similarity index 100% rename from examples/mps/shaders.metallib rename to examples/mps/ray-intersection/shaders.metallib diff --git a/src/mps.rs b/src/mps.rs index 8b0a3b9..c0c8e00 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -7,7 +7,6 @@ use super::*; use half::{bf16, f16}; -use objc::rc::autoreleasepool; use objc::runtime::{BOOL, YES}; use std::fmt::Debug; use std::hash::Hash; @@ -325,7 +324,7 @@ impl PolygonAccelerationStructureRef { } pub fn set_index_type(&self, _data_type: T) { - unsafe { msg_send![self, setIndexType: T::ENCODING] } + unsafe { msg_send![self, setIndexType: T::CODE] } } pub fn set_mask_buffer(&self, buffer: Option<&BufferRef>) { @@ -555,10 +554,14 @@ pub struct MPSIntersectionDistancePrimitiveIndexCoordinates { /// See . pub trait MPSDataType: Clone + Copy + PartialEq + Eq + Debug + Hash { type Type: Default + Clone + Copy + PartialEq + Debug + Sized; - const ENCODING: u32; + const CODE: u32; /// See . - const SIZE: u32 = ((Self::ENCODING & 0xFFFF) >> 3); + const SIZE: u32 = ((Self::CODE & 0xFFFF) >> 3); + + fn from_f64(v: f64) -> Self::Type; + + fn to_f64(v: Self::Type) -> f64; } /// A common bit for all floating point data types. Zero for integer types @@ -574,106 +577,261 @@ const MPS_ALTERNATE_ENCODING: u32 = 0x80000000; /// Normalized values have range [0, 1.0] if unsigned and [-1,1] if signed. /// SNORM_TYPE_MIN is interpreted as SNORM_TYPE_MIN+1 per standard Metal rules. const MPS_NORMALIZEDBIT_ENCODING: u32 = 0x40000000; -macro_rules! mps_datatype { - ($dt:ident, $dt_ty:ty, $encoding:expr, $comment:expr) => { - #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] - #[doc=$comment] - pub struct $dt; +macro_rules! mps_datatype_impl { + ($dt:ident, $dt_ty:ty, $code:expr, $from_f64:expr, $to_f64:expr) => { impl MPSDataType for $dt { type Type = $dt_ty; - const ENCODING: u32 = $encoding; + const CODE: u32 = $code; + + fn from_f64(v: f64) -> Self::Type { + $from_f64(v) + } + + fn to_f64(v: Self::Type) -> f64 { + $to_f64(v) + } } }; } -mps_datatype!(Invalid, (), 0, "An invalid data type."); +macro_rules! mps_datatype { + ($dt:ident, $dt_ty:ty, $code:expr, $from_f64:expr, $to_f64:expr, $comment:expr) => { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + #[doc=$comment] + pub struct $dt; + mps_datatype_impl!($dt, $dt_ty, $code, $from_f64, $to_f64); + }; + ($dt:ident, $dt_ty:ty, $code:expr, $from_f64:expr, $to_f64:expr) => { + #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] + pub struct $dt; + + mps_datatype_impl!($dt, $dt_ty, $code, $from_f64, $to_f64); + }; +} +mps_datatype!(Invalid, (), 0, |_: f64| (), |_: ()| 0.0); mps_datatype!( Float32, f32, MPS_FLOATBIT_ENCODING | 32, + |v: f64| v as f32, + |v: f32| v as f64, "32-bit floating point (single-precision)." ); mps_datatype!( Float16, f16, MPS_FLOATBIT_ENCODING | 16, + |v: f64| f16::from_f64(v), + |v: f16| v.to_f64(), "16-bit floating point (half-precision). (IEEE-754-2008 float16 exchange format)" ); +fn unpack_f32_tuple(packed: f64) -> (f32, f32) { + let packed_bits = packed.to_bits(); + let f1_bits = (packed_bits >> 32) as u32; + let f2_bits = (packed_bits & 0xFFFFFFFF) as u32; + (f32::from_bits(f1_bits), f32::from_bits(f2_bits)) +} + +fn pack_f32_tuple((f1, f2): (f32, f32)) -> f64 { + let f1_bits = f1.to_bits(); + let f2_bits = f2.to_bits(); + let packed = ((f1_bits as u64) << 32) | (f2_bits as u64); + f64::from_bits(packed) +} + mps_datatype!( ComplexFloat32, (f32, f32), MPS_COMPLEXBIT_ENCODING | 64, + unpack_f32_tuple, + pack_f32_tuple, "Complex number composed of two 32-bit floating point numbers (single-precision)." ); -mps_datatype!(ComplexFloat16, (f16, f16), MPS_COMPLEXBIT_ENCODING | 32, "Complex number composed of two 16-bit floating point numbers (half-precision). (IEEE-754-2008 float16 exchange format)"); + +fn unpack_f16_tuple(packed: f64) -> (f16, f16) { + let packed_bits = packed.to_bits(); + let f1_bits = (packed_bits >> 16) as u16; + let f2_bits = (packed_bits & 0xFFFF) as u16; + (f16::from_bits(f1_bits), f16::from_bits(f2_bits)) +} + +fn pack_f16_tuple((f1, f2): (f16, f16)) -> f64 { + let f1_bits = f1.to_bits(); + let f2_bits = f2.to_bits(); + let packed = ((f1_bits as u64) << 16) | (f2_bits as u64); + f64::from_bits(packed) +} +mps_datatype!( + ComplexFloat16, + (f16, f16), + MPS_COMPLEXBIT_ENCODING | 32, + unpack_f16_tuple, + pack_f16_tuple, + "Complex number composed of two 16-bit floating point numbers (half-precision). (IEEE-754-2008 float16 exchange format)" +); mps_datatype!( Int8, i8, MPS_SIGNEDBIT_ENCODING | 8, + |v: f64| v as i8, + |v: i8| v as f64, "Signed 8-bit integer." ); mps_datatype!( Int16, i16, MPS_SIGNEDBIT_ENCODING | 16, + |v: f64| v as i16, + |v: i16| v as f64, "Signed 16-bit integer." ); mps_datatype!( Int32, i32, MPS_SIGNEDBIT_ENCODING | 32, + |v: f64| v as i32, + |v: i32| v as f64, "Signed 32-bit integer." ); mps_datatype!( Int64, i64, MPS_SIGNEDBIT_ENCODING | 64, + |v: f64| v as i64, + |v: i64| v as f64, "Signed 64-bit integer." ); - -mps_datatype!(UInt8, u8, 8, "Unsigned 8-bit integer. Not normalized"); -mps_datatype!(UInt16, u16, 16, "Unsigned 16-bit integer. Not normalized"); -mps_datatype!(UInt32, u32, 32, "Unsigned 32-bit integer. Not normalized"); -mps_datatype!(UInt64, u64, 64, "Unsigned 64-bit integer. Not normalized"); - +mps_datatype!( + UInt8, + u8, + 8, + |v: f64| v as u8, + |v: u8| v as f64, + "Unsigned 8-bit integer. Not normalized" +); +mps_datatype!( + UInt16, + u16, + 16, + |v: f64| v as u16, + |v: u16| v as f64, + "Unsigned 16-bit integer. Not normalized" +); +mps_datatype!( + UInt32, + u32, + 32, + |v: f64| v as u32, + |v: u32| v as f64, + "Unsigned 32-bit integer. Not normalized" +); +mps_datatype!( + UInt64, + u64, + 64, + |v: f64| v as u64, + |v: u64| v as f64, + "Unsigned 64-bit integer. Not normalized" +); mps_datatype!( Bool, bool, MPS_ALTERNATE_ENCODING | 8, + |v: f64| v != 0.0, + |v: bool| if v { 1.0 } else { 0.0 }, "Boolean as 8-bit integer. Not normalized." ); mps_datatype!( BF16, bf16, MPS_ALTERNATE_ENCODING | MPS_FLOATBIT_ENCODING | 16, + |v: f64| bf16::from_f64(v), + |v: bf16| v.to_f64(), "Boolean as 8-bit integer. Not normalized." ); mps_datatype!( UNorm1, bool, MPS_NORMALIZEDBIT_ENCODING | 1, + |v: f64| v != 0.0, + |v: bool| if v { 1.0 } else { 0.0 }, "Unsigned 1-bit normalized value." ); mps_datatype!( UNorm8, u8, MPS_NORMALIZEDBIT_ENCODING | 8, + |v: f64| v as u8, + |v: u8| v as f64, "Unsigned 8-bit normalized value." ); +/// Helper trait used indicates that a type constraint is valid. +pub trait Valid {} + +/// Helper struct used to indicate a valid matrix multiplication input type. +pub struct MatMulInput { + _marker: PhantomData, +} + +/// Input data type must be one of MPSDataTypeFloat32, MPSDataTypeFloat16, MPSDataTypeInt8, +/// or MPSDataTypeInt16 +impl Valid for MatMulInput {} +impl Valid for MatMulInput {} +impl Valid for MatMulInput {} +impl Valid for MatMulInput {} + +/// Helper struct used to indicate a valid matrix multiplication result type. +pub struct MatMulResult { + _marker: PhantomData, +} + +/// Only MPSDataTypeFloat16 and MPSDataTypeFloat32 are supported for the result matrix. +impl Valid for MatMulResult {} +impl Valid for MatMulResult {} + +/// Helper struct used to indicate valid matrix multiplication types. +pub struct MatMulSpecification +where + A: MPSDataType, + B: MPSDataType, + C: MPSDataType, + MatMulInput: Valid, + MatMulInput: Valid, + MatMulResult: Valid, +{ + _marker: PhantomData<(A, B, C)>, +} + +/// Mixed input matrix multiplication is only for +impl Valid for MatMulSpecification {} + +/// All valid input types can produce a MPSDataTypeFloat32 result. +impl Valid for MatMulSpecification +where + T: MPSDataType, + MatMulInput: Valid, +{ +} + +/// These input types can produce a MPSDataTypeFloat16 result. +impl Valid for MatMulSpecification {} +impl Valid for MatMulSpecification {} +impl Valid for MatMulSpecification {} + /// See pub enum MPSMatrixDescriptor {} foreign_obj_type! { type CType = MPSMatrixDescriptor; - pub struct MatrixDescriptorObject; + pub struct MatrixDescriptor; type ParentType = NsObject; } -impl MatrixDescriptorObject { +impl MatrixDescriptor { fn init_single( rows: NSUInteger, columns: NSUInteger, @@ -695,8 +853,8 @@ impl MatrixDescriptorObject { rows: NSUInteger, columns: NSUInteger, matrices: NSUInteger, - matrix_bytes: NSUInteger, row_bytes: NSUInteger, + matrix_bytes: NSUInteger, data_type: u32, ) -> Self { unsafe { @@ -723,66 +881,12 @@ impl MatrixDescriptorObject { } } -#[derive(Debug)] -pub struct MatrixDescriptor { - pub object: MatrixDescriptorObject, - pub rows: NSUInteger, - pub columns: NSUInteger, - pub matrices: NSUInteger, - pub row_bytes: NSUInteger, - pub matrix_bytes: NSUInteger, - pub _marker: PhantomData, -} - -impl MatrixDescriptor { - pub fn single(rows: NSUInteger, columns: NSUInteger) -> Self { +impl From<&Matrix> for MatrixDescriptor { + fn from(matrix: &Matrix) -> Self { + let data_type = T::CODE; // The number of bytes between starting elements of consecutive rows. - let row_bytes = Self::row_bytes_for_columns(columns); - - let object = MatrixDescriptorObject::init_single(rows, columns, row_bytes, T::ENCODING); - MatrixDescriptor { - object, - rows, - columns, - row_bytes, - matrices: 1, - matrix_bytes: row_bytes, - _marker: PhantomData, - } - } - - pub fn multiple(rows: NSUInteger, columns: NSUInteger, matrices: NSUInteger) -> Self { - // The number of bytes between starting elements of consecutive rows. - let row_bytes = Self::row_bytes_for_columns(columns); - // The number of bytes between starting elements of consecutive matrices. - let matrix_bytes = row_bytes * rows; - - let object = MatrixDescriptorObject::init_multiple( - rows, - columns, - matrices, - matrix_bytes, - row_bytes, - T::ENCODING, - ); - MatrixDescriptor { - object, - rows, - columns, - matrices, - row_bytes, - matrix_bytes, - _marker: PhantomData, - } - } - - pub fn row_bytes_for_columns(columns: NSUInteger) -> NSUInteger { - MatrixDescriptorObject::row_bytes_for_columns(columns, T::ENCODING) - } - - pub fn required_buffer_size(&self) -> NSUInteger { - self.matrices * self.matrix_bytes + self.rows * self.row_bytes - //+ T::SIZE as NSUInteger * self.columns + let row_bytes = MatrixDescriptor::row_bytes_for_columns(matrix.columns, data_type); + Self::init_single(matrix.rows, matrix.columns, row_bytes, data_type) } } @@ -795,10 +899,18 @@ foreign_obj_type! { type ParentType = NsObject; } +/// Generic matrix for MPSDataTypes. +#[derive(Debug)] +pub struct Matrix { + pub entries: Vec, // row-major order + pub rows: NSUInteger, + pub columns: NSUInteger, +} + impl MatrixObject { fn init_with_device_descriptor( device: &DeviceRef, - descriptor: &MatrixDescriptorObjectRef, + descriptor: &MatrixDescriptorRef, ) -> Option { unsafe { let matrix: MatrixObject = msg_send![class!(MPSMatrix), alloc]; @@ -817,11 +929,8 @@ impl MatrixObject { fn init_with_buffer_descriptor( buffer: &BufferRef, - descriptor: &MatrixDescriptorObjectRef, + descriptor: &MatrixDescriptorRef, ) -> Option { - // assert!(buffer.length() >= descriptor.rowBytes() * descriptor.rows()); - // assert_eq!(buffer.length() % descriptor.rowBytes(), 0); - // assert_eq!(buffer.device(), descriptor.device()); unsafe { let matrix: MatrixObject = msg_send![class!(MPSMatrix), alloc]; let ptr: *mut Object = msg_send![ @@ -843,15 +952,15 @@ impl MatrixObjectRef { unsafe { msg_send![self, device] } } - pub fn rows(&self) -> u64 { + pub fn rows(&self) -> NSUInteger { unsafe { msg_send![self, rows] } } - pub fn columns(&self) -> u64 { + pub fn columns(&self) -> NSUInteger { unsafe { msg_send![self, columns] } } - pub fn row_bytes(&self) -> u64 { + pub fn row_bytes(&self) -> NSUInteger { unsafe { msg_send![self, rowBytes] } } @@ -863,43 +972,11 @@ impl MatrixObjectRef { unsafe { msg_send![self, data] } } - pub fn resource_size(&self) -> u64 { + pub fn resource_size(&self) -> NSUInteger { unsafe { msg_send![self, resourceSize] } } } -#[derive(Debug)] -pub struct Matrix { - pub object: MatrixObject, - pub entries: Vec, // row-major order - pub rows: u64, - pub columns: u64, -} - -impl Matrix { - pub fn init(device: &DeviceRef, descriptor: &MatrixDescriptor) -> Self { - let object = MatrixObject::init_with_device_descriptor(device, &descriptor.object).unwrap(); - let entries = vec![T::Type::default(); (&descriptor.rows * &descriptor.columns) as usize]; - Matrix { - object, - entries, - rows: descriptor.rows, - columns: descriptor.columns, - } - } - - pub fn init_with_buffer(buffer: &BufferRef, descriptor: &MatrixDescriptor) -> Self { - let object = MatrixObject::init_with_buffer_descriptor(buffer, &descriptor.object).unwrap(); - let entries = vec![T::Type::default(); (&descriptor.rows * &descriptor.columns) as usize]; - Matrix { - object, - entries, - rows: descriptor.rows, - columns: descriptor.columns, - } - } -} - /// A kernel for matrix multiplication. /// /// Computes the following operation: @@ -915,14 +992,13 @@ pub enum MPSMatrixMultiplication {} foreign_obj_type! { type CType = MPSMatrixMultiplication; - pub struct MatrixMultiplicationKernel; + pub struct MatrixMultiplication; type ParentType = Kernel; } -impl MatrixMultiplicationKernel { +impl MatrixMultiplication { pub fn from_device(device: &DeviceRef) -> Option { unsafe { - let kernel: MatrixMultiplicationKernel = - msg_send![class!(MPSMatrixMultiplication), alloc]; + let kernel: MatrixMultiplication = msg_send![class!(MPSMatrixMultiplication), alloc]; let ptr: *mut Object = msg_send![kernel.as_ref(), initWithDevice: device]; if ptr.is_null() { None @@ -939,12 +1015,15 @@ impl MatrixMultiplicationKernel { result_rows: NSUInteger, result_columns: NSUInteger, interior_columns: NSUInteger, - alpha: f32, - beta: f32, + alpha: f64, + beta: f64, ) -> Option { + assert!(result_rows > 0); + assert!(result_columns > 0); + assert!(interior_columns > 0); + unsafe { - let kernel: MatrixMultiplicationKernel = - msg_send![class!(MPSMatrixMultiplication), alloc]; + let kernel: MatrixMultiplication = msg_send![class!(MPSMatrixMultiplication), alloc]; let ptr: *mut Object = msg_send![ kernel.as_ref(), initWithDevice : device @@ -971,8 +1050,7 @@ impl MatrixMultiplicationKernel { interior_columns: NSUInteger, ) -> Option { unsafe { - let kernel: MatrixMultiplicationKernel = - msg_send![class!(MPSMatrixMultiplication), alloc]; + let kernel: MatrixMultiplication = msg_send![class!(MPSMatrixMultiplication), alloc]; let ptr: *mut Object = msg_send![ kernel.as_ref(), initWithDevice : device @@ -989,189 +1067,12 @@ impl MatrixMultiplicationKernel { } } -#[derive(Debug)] -struct MatrixMultiplication { - kernel: MatrixMultiplicationKernel, - transpose_left: bool, - transpose_right: bool, - result_rows: NSUInteger, - result_columns: NSUInteger, - interior_columns: NSUInteger, - alpha: f32, - beta: f32, -} - -/// Helper trait used indicates that a type constraint is valid. -trait Valid {} - -/// Helper struct used to indicate a valid matrix multiplication input type. -struct MatMulInput { - _marker: PhantomData, -} - -/// Input data type must be one of MPSDataTypeFloat32, MPSDataTypeFloat16, MPSDataTypeInt8, -/// or MPSDataTypeInt16 -impl Valid for MatMulInput {} -impl Valid for MatMulInput {} -impl Valid for MatMulInput {} -impl Valid for MatMulInput {} - -/// Helper struct used to indicate a valid matrix multiplication result type. -struct MatMulResult { - _marker: PhantomData, -} - -/// Only MPSDataTypeFloat16 and MPSDataTypeFloat32 are supported for the result matrix. -impl Valid for MatMulResult {} -impl Valid for MatMulResult {} - -/// Helper struct used to indicate valid matrix multiplication types. -struct MatMulSpecification -where - Left: MPSDataType, - MatMulInput: Valid, - Right: MPSDataType, - MatMulInput: Valid, - Result: MPSDataType, - MatMulResult: Valid, -{ - _marker: PhantomData<(Left, Right, Result)>, -} - -/// Mixed input matrix multiplication is only for -impl Valid for MatMulSpecification {} - -/// All valid input types can produce a MPSDataTypeFloat32 result. -impl Valid for MatMulSpecification -where - Input: MPSDataType, - MatMulInput: Valid, -{ -} - -/// These input types can produce a MPSDataTypeFloat16 result. -impl Valid for MatMulSpecification {} -impl Valid for MatMulSpecification {} -impl Valid for MatMulSpecification {} - -impl MatrixMultiplication { - pub fn init( - device: &DeviceRef, - transpose_left: bool, - transpose_right: bool, - result_rows: NSUInteger, - result_columns: NSUInteger, - interior_columns: NSUInteger, - alpha: f32, - beta: f32, - ) -> Option { - assert!(result_rows > 0); - assert!(result_columns > 0); - assert!(interior_columns > 0); - if let Some(kernel) = MatrixMultiplicationKernel::init( - device, - transpose_left, - transpose_right, - result_rows, - result_columns, - interior_columns, - alpha, - beta, - ) { - return Some(MatrixMultiplication { - kernel, - transpose_left, - transpose_right, - result_rows, - result_columns, - interior_columns, - alpha, - beta, - }); - } - None - } - - pub fn init_simple( - device: &DeviceRef, - result_rows: NSUInteger, - result_columns: NSUInteger, - interior_columns: NSUInteger, - ) -> Option { - assert!(result_rows > 0); - assert!(result_columns > 0); - assert!(interior_columns > 0); - if let Some(kernel) = MatrixMultiplicationKernel::init_simple( - device, - result_rows, - result_columns, - interior_columns, - ) { - return Some(MatrixMultiplication { - kernel, - transpose_left: false, - transpose_right: false, - result_rows, - result_columns, - interior_columns, - alpha: 1.0, - beta: 0.0, - }); - } - None - } +impl MatrixMultiplicationRef { /// Encode the kernel to the given command buffer. /// * `command_buffer` - The command buffer to encode the kernel to. /// * `left_matrix` - The left matrix to multiply. /// * `right_matrix` - The right matrix to multiply. /// * `result_matrix` - The matrix to store the result in. - pub fn encode_to_command_buffer( - &self, - command_buffer: &CommandBufferRef, - left_matrix: &Matrix, - right_matrix: &Matrix, - result_matrix: &Matrix, - ) where - T: MPSDataType, - U: MPSDataType, - V: MPSDataType, - MatMulInput: Valid, - MatMulInput: Valid, - MatMulResult: Valid, - MatMulSpecification: Valid, - { - // Certain constraints apply to the sizes of the matrices depending on the transposition - // operations and sizes requested at initialization time as well as the origins at the time - // this routine is called: - // - // The left input matrix must be large enough to hold an array of size resultRows x interiorColumns - // elements beginning at leftMatrixOrigin. - assert!(left_matrix.rows * left_matrix.columns >= self.result_rows * self.interior_columns); - // The right input matrix must be large enough to hold an array of size - // interiorColumns x resultColumns elements beginning at rightMatrixOrigin. - assert!( - right_matrix.rows * right_matrix.columns >= self.interior_columns * self.result_columns - ); - // The result matrix must be large enough to hold an array of size resultRows x resultColumns - // elements beginning at resultMatrixOrigin. - assert!( - result_matrix.rows * result_matrix.columns >= self.result_rows * self.result_columns - ); - - // Each matrix within the range specified by batchStart and batchSize, which also specifies - // a valid set of matrices within leftMatrix, rightMatrix, and resultMatrix, will - // be processed. - - self.kernel.encode_to_command_buffer( - command_buffer, - &left_matrix.object, - &right_matrix.object, - &result_matrix.object, - ); - } -} - -impl MatrixMultiplicationKernelRef { pub fn encode_to_command_buffer( &self, command_buffer: &CommandBufferRef, @@ -1202,33 +1103,31 @@ pub struct MatrixBuffer { impl MatrixBuffer { pub fn new( device: &DeviceRef, - descriptor: &MatrixDescriptor, + rows: NSUInteger, + columns: NSUInteger, + length: NSUInteger, options: MTLResourceOptions, ) -> Self { - let buffer = device.new_buffer(descriptor.required_buffer_size(), options); + let buffer = device.new_buffer(length, options); MatrixBuffer { buffer, - rows: descriptor.rows, - columns: descriptor.columns, + rows, + columns, _marker: PhantomData, } } pub fn new_with_data( device: &DeviceRef, - entries: &Vec, - descriptor: &MatrixDescriptor, + matrix: &Matrix, + length: NSUInteger, options: MTLResourceOptions, ) -> Self { - let buffer = device.new_buffer_with_data( - entries.as_ptr().cast(), - descriptor.required_buffer_size(), - options, - ); + let buffer = device.new_buffer_with_data(matrix.entries.as_ptr().cast(), length, options); MatrixBuffer { buffer, - rows: descriptor.rows, - columns: descriptor.columns, + rows: matrix.rows, + columns: matrix.columns, _marker: PhantomData, } } @@ -1241,114 +1140,139 @@ impl MatrixBuffer { } } -fn matmul( +pub fn read_buffer_to_vec(buffer: &BufferRef, len: usize) -> Vec { + Vec::from(unsafe { std::slice::from_raw_parts(buffer.contents() as *const T, len) }) +} + +pub fn matrix_multiplication( transpose_left: bool, transpose_right: bool, - left_descriptor: MatrixDescriptor, - left_entries: &Vec, - right_descriptor: MatrixDescriptor, - right_entries: &Vec, - alpha: f32, - beta: f32, -) -> Vec + left: &Matrix, + right: &Matrix, + alpha: f64, + beta: f64, +) -> Matrix where - T: MPSDataType, - U: MPSDataType, - V: MPSDataType, - MatMulInput: Valid, - MatMulInput: Valid, - MatMulResult: Valid, - MatMulSpecification: Valid, + A: MPSDataType, + B: MPSDataType, + C: MPSDataType, + MatMulInput: Valid, + MatMulInput: Valid, + MatMulResult: Valid, + MatMulSpecification: Valid, { - // For matrix multiplication, the number of columns in the first matrix must be equal to - // the number of rows in the second matrix. - // The result matrix has the number of rows of the first and the number of columns of the - // second matrix. - let result_descriptor = - MatrixDescriptor::::single(left_descriptor.rows, right_descriptor.columns); + validate_matrix_multiplication(transpose_left, transpose_right, left, right); let device = Device::system_default().expect("No device found"); - let matrix_multiplication = MatrixMultiplication::init_simple( + let command_queue = device.new_command_queue(); + + let M = left.rows; + let N = right.columns; + let K = left.columns; + + // Create descriptors for the matrices. + let left_row_bytes = MatrixDescriptor::row_bytes_for_columns(K, A::CODE); + let right_row_bytes = MatrixDescriptor::row_bytes_for_columns(N, B::CODE); + let result_row_bytes = MatrixDescriptor::row_bytes_for_columns(N, C::CODE); + + let left_descriptor = MatrixDescriptor::init_single(M, K, left_row_bytes, A::CODE); + let right_descriptor = MatrixDescriptor::init_single(K, N, right_row_bytes, B::CODE); + let result_descriptor = MatrixDescriptor::init_single(M, N, result_row_bytes, C::CODE); + + // Create buffers + let options = MTLResourceOptions::StorageModeShared; + let left_buffer = + device.new_buffer_with_data(left.entries.as_ptr().cast(), M * left_row_bytes, options); + let right_buffer = + device.new_buffer_with_data(right.entries.as_ptr().cast(), K * right_row_bytes, options); + let result_buffer = device.new_buffer(M * result_row_bytes, options); + + // Create matrix objects + let left_matrix = + MatrixObject::init_with_buffer_descriptor(&left_buffer, &left_descriptor).unwrap(); + let right_matrix = + MatrixObject::init_with_buffer_descriptor(&right_buffer, &right_descriptor).unwrap(); + let result_matrix = + MatrixObject::init_with_buffer_descriptor(&result_buffer, &result_descriptor).unwrap(); + + // Create kernel + let matrix_multiplication = MatrixMultiplication::init( &device, - left_descriptor.rows, - right_descriptor.columns, - left_descriptor.columns, + transpose_left, + transpose_right, + M, + N, + K, + alpha, + beta, ) .unwrap(); - let command_queue = device.new_command_queue(); let command_buffer = command_queue.new_command_buffer(); - let left_buffer = MatrixBuffer::new_with_data( - &device, - left_entries, - &left_descriptor, - MTLResourceOptions::StorageModeShared, - ); - let right_buffer = MatrixBuffer::new_with_data( - &device, - right_entries, - &right_descriptor, - MTLResourceOptions::StorageModeShared, - ); - let result_buffer = MatrixBuffer::new( - &device, - &result_descriptor, - MTLResourceOptions::StorageModeShared, - ); - - let left_matrix = Matrix::init_with_buffer(&left_buffer.buffer, &left_descriptor); - - let right_matrix = Matrix::init_with_buffer(&right_buffer.buffer, &right_descriptor); - - let result_matrix = Matrix::init_with_buffer(&result_buffer.buffer, &result_descriptor); - + // Encode kernel to command buffer matrix_multiplication.encode_to_command_buffer( &command_buffer, &left_matrix, &right_matrix, &result_matrix, ); + // Run multiplication command_buffer.commit(); command_buffer.wait_until_completed(); - let result = result_buffer.contents(); - result + // Get result from buffer + let entries = read_buffer_to_vec::(&result_buffer, (M * N) as usize); + Matrix { + entries, + rows: M, + columns: N, + } } -#[cfg(test)] -mod tests { - use super::*; - use rand::Rng; - - fn random_matrix(rows: usize, columns: usize) -> (Vec, NSUInteger, NSUInteger) { - let mut rng = rand::thread_rng(); - let mut entries = vec![0.0; rows * columns]; - for i in 0..rows { - for j in 0..columns { - entries[i * columns + j] = rng.gen(); - } - } - (entries, rows as NSUInteger, columns as NSUInteger) - } - - #[test] - fn test_matrix_multiplication() { - let (left_entries, l_rows, l_columns) = random_matrix(1024, 1024); - let (right_entries, r_rows, r_columns) = random_matrix(1024, 1024); - autoreleasepool(|| { - let result = matmul( - false, - false, - MatrixDescriptor::::single(l_rows, l_columns), - &left_entries, - MatrixDescriptor::::single(r_rows, r_columns), - &right_entries, - 1.0, - 0.0, - ); +fn validate_matrix_multiplication( + transpose_left: bool, + transpose_right: bool, + left: &Matrix, + right: &Matrix, +) where + A: MPSDataType, + B: MPSDataType, + C: MPSDataType, + MatMulInput: Valid, + MatMulInput: Valid, + MatMulResult: Valid, + MatMulSpecification: Valid, +{ - println!("{:?}", result.len()); - }); - } + // TODO ... + + // For matrix multiplication, the number of columns in the first matrix must be equal to + // the number of rows in the second matrix. + // The result matrix has the number of rows of the first and the number of columns of the + // second matrix. + // If only one matrix is transposed then the result matrix has the number of rows of the + // transposed matrix and the number of columns of the non-transposed matrix. + + // Certain constraints apply to the sizes of the matrices depending on the transposition + // operations and sizes requested at initialization time as well as the origins at the time + // this routine is called: + // + // The left input matrix must be large enough to hold an array of size resultRows x interiorColumns + // elements beginning at leftMatrixOrigin. + // assert!(left_matrix.rows * left_matrix.columns >= self.result_rows * self.interior_columns); + // The right input matrix must be large enough to hold an array of size + // interiorColumns x resultColumns elements beginning at rightMatrixOrigin. + // assert!( + // right_matrix.rows * right_matrix.columns >= self.interior_columns * self.result_columns + // ); + // The result matrix must be large enough to hold an array of size resultRows x resultColumns + // elements beginning at resultMatrixOrigin. + // assert!( + // result_matrix.rows * result_matrix.columns >= self.result_rows * self.result_columns + // ); + + // Each matrix within the range specified by batchStart and batchSize, which also specifies + // a valid set of matrices within leftMatrix, rightMatrix, and resultMatrix, will + // be processed. } From 530e2c11f7fe353b24cc1faaa9ed818ab3e5e3ff Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Sat, 21 Oct 2023 13:57:26 +0200 Subject: [PATCH 03/12] Use "apply" scheme for gemm function. Buffer reading bug needs fix. --- examples/mps/matrix-multiplication/main.rs | 31 ++++++++++-- src/mps.rs | 57 ++++++++++++---------- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs index ed54c46..4e8d0c5 100644 --- a/examples/mps/matrix-multiplication/main.rs +++ b/examples/mps/matrix-multiplication/main.rs @@ -1,7 +1,7 @@ use metal::mps::*; use metal::*; -fn generate_matrix() -> Matrix +fn generate_matrix() -> Matrix where T: MPSDataType, MatMulInput: Valid, @@ -17,9 +17,9 @@ fn main() { type A = Float32; type B = Float32; type C = Float32; - const M: usize = 1; - const N: usize = 1; - const K: usize = 5; + const M: u64 = 2; + const N: u64 = 2; + const K: u64 = 2; let transpose_left = false; let transpose_right = false; @@ -32,6 +32,27 @@ fn main() { println!("{left:?}"); println!("{right:?}"); - let result = matrix_multiplication(transpose_left, transpose_right, &left, &right, alpha, beta); + let device = Device::system_default().expect("No device found"); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + // Add matrix multiplication to command buffer and get result buffer + let result_buffer = apply_gemm( + &device, + command_buffer, + transpose_left, + transpose_right, + &left, + &right, + alpha, + beta, + ); + + // Run multiplication + command_buffer.commit(); + command_buffer.wait_until_completed(); + + // Read result buffer + let result = result_buffer.contents(); println!("{result:?}"); } diff --git a/src/mps.rs b/src/mps.rs index c0c8e00..b1f70f1 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -1138,20 +1138,25 @@ impl MatrixBuffer { unsafe { std::slice::from_raw_parts(contents, (self.rows * self.columns) as usize) }; sl.to_vec() } + pub fn read_to_vec(&self) -> Vec { + read_buffer_to_vec(&self.buffer, (self.rows * self.columns) as usize) + } } pub fn read_buffer_to_vec(buffer: &BufferRef, len: usize) -> Vec { Vec::from(unsafe { std::slice::from_raw_parts(buffer.contents() as *const T, len) }) } -pub fn matrix_multiplication( +pub fn apply_gemm( + device: &DeviceRef, + command_buffer: &CommandBufferRef, transpose_left: bool, transpose_right: bool, left: &Matrix, right: &Matrix, alpha: f64, beta: f64, -) -> Matrix +) -> MatrixBuffer where A: MPSDataType, B: MPSDataType, @@ -1161,14 +1166,23 @@ where MatMulResult: Valid, MatMulSpecification: Valid, { - validate_matrix_multiplication(transpose_left, transpose_right, left, right); - - let device = Device::system_default().expect("No device found"); - let command_queue = device.new_command_queue(); + let M = if transpose_left { + left.columns + } else { + left.rows + }; + let N = if transpose_right { + right.rows + } else { + right.columns + }; + let K = if transpose_left { + left.rows + } else { + left.columns + }; - let M = left.rows; - let N = right.columns; - let K = left.columns; + validate_matrix_multiplication(left, right, M, N, K); // Create descriptors for the matrices. let left_row_bytes = MatrixDescriptor::row_bytes_for_columns(K, A::CODE); @@ -1185,7 +1199,8 @@ where device.new_buffer_with_data(left.entries.as_ptr().cast(), M * left_row_bytes, options); let right_buffer = device.new_buffer_with_data(right.entries.as_ptr().cast(), K * right_row_bytes, options); - let result_buffer = device.new_buffer(M * result_row_bytes, options); + + let result_buffer = MatrixBuffer::new(device, M, N, M * result_row_bytes, options); // Create matrix objects let left_matrix = @@ -1193,7 +1208,8 @@ where let right_matrix = MatrixObject::init_with_buffer_descriptor(&right_buffer, &right_descriptor).unwrap(); let result_matrix = - MatrixObject::init_with_buffer_descriptor(&result_buffer, &result_descriptor).unwrap(); + MatrixObject::init_with_buffer_descriptor(&result_buffer.buffer, &result_descriptor) + .unwrap(); // Create kernel let matrix_multiplication = MatrixMultiplication::init( @@ -1208,8 +1224,6 @@ where ) .unwrap(); - let command_buffer = command_queue.new_command_buffer(); - // Encode kernel to command buffer matrix_multiplication.encode_to_command_buffer( &command_buffer, @@ -1217,24 +1231,17 @@ where &right_matrix, &result_matrix, ); - // Run multiplication - command_buffer.commit(); - command_buffer.wait_until_completed(); - // Get result from buffer - let entries = read_buffer_to_vec::(&result_buffer, (M * N) as usize); - Matrix { - entries, - rows: M, - columns: N, - } + // Return result buffer + result_buffer } fn validate_matrix_multiplication( - transpose_left: bool, - transpose_right: bool, left: &Matrix, right: &Matrix, + M: NSUInteger, + N: NSUInteger, + K: NSUInteger, ) where A: MPSDataType, B: MPSDataType, From 5cd97c31844055961adf068ae26c41644a387fb8 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Wed, 25 Oct 2023 14:46:44 +0200 Subject: [PATCH 04/12] Matrix multiplication example is now a performance test. Seems to work except GFLOPS calculation --- examples/mps/matrix-multiplication/main.rs | 81 +++---- src/buffer.rs | 10 + src/mps.rs | 236 ++++++++++----------- 3 files changed, 170 insertions(+), 157 deletions(-) diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs index 4e8d0c5..d833b65 100644 --- a/examples/mps/matrix-multiplication/main.rs +++ b/examples/mps/matrix-multiplication/main.rs @@ -1,58 +1,67 @@ use metal::mps::*; use metal::*; +use rand::{thread_rng, Rng}; fn generate_matrix() -> Matrix where T: MPSDataType, - MatMulInput: Valid, + GEMMInput: Valid, { - Matrix { - entries: (1..=ROWS * COLS).map(|i| T::from_f64(i as f64)).collect(), - rows: ROWS as NSUInteger, - columns: COLS as NSUInteger, - } + let mut rng = thread_rng(); + Matrix::new( + (0..ROWS * COLS).map(|_| T::from_f64(rng.gen())).collect(), + ROWS as NSUInteger, + COLS as NSUInteger, + ) } fn main() { - type A = Float32; - type B = Float32; - type C = Float32; - const M: u64 = 2; - const N: u64 = 2; - const K: u64 = 2; + const M: u64 = 4096; + const N: u64 = 4096; + const K: u64 = 4096; + const RUNS: u64 = 100; let transpose_left = false; let transpose_right = false; let alpha = 1.0; let beta = 0.0; - let left = generate_matrix::(); - let right = generate_matrix::(); - - println!("{left:?}"); - println!("{right:?}"); + // Generate random matrices + let left = generate_matrix::(); + let right = generate_matrix::(); + // Setup let device = Device::system_default().expect("No device found"); let command_queue = device.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); + let mut total_time = std::time::Duration::new(0, 0); - // Add matrix multiplication to command buffer and get result buffer - let result_buffer = apply_gemm( - &device, - command_buffer, - transpose_left, - transpose_right, - &left, - &right, - alpha, - beta, - ); - - // Run multiplication - command_buffer.commit(); - command_buffer.wait_until_completed(); + for _ in 0..RUNS { + let command_buffer = command_queue.new_command_buffer(); + let start = std::time::Instant::now(); + let _ = encode_gemm( + &device, + command_buffer, + transpose_left, + transpose_right, + &left, + &right, + alpha, + beta, + ); + command_buffer.commit(); + command_buffer.wait_until_completed(); + let time = std::time::Instant::now() - start; + total_time += time; + } - // Read result buffer - let result = result_buffer.contents(); - println!("{result:?}"); + // Calculate GFLOPS + // C <- alpha * AB + beta * C + // Operations = M * N * (K+2) + M * N * K + let ops_count = M * N * (2 * K + 2); + let ops_count = (ops_count * RUNS) as f64; + let gflops = ops_count / (total_time.as_secs_f64() * 1000e+3f64); + // TODO: Something is wrong here hehe + println!("GFLOPS: {}", gflops); + println!("Total time: {:?}", total_time); + println!("Avg time: {:?}", total_time / RUNS as u32); } diff --git a/src/buffer.rs b/src/buffer.rs index 8f3108a..6496e97 100644 --- a/src/buffer.rs +++ b/src/buffer.rs @@ -68,4 +68,14 @@ impl BufferRef { pub fn gpu_address(&self) -> u64 { unsafe { msg_send![self, gpuAddress] } } + + pub fn read_to_slice(&self, len: usize) -> &[T] { + let contents_ptr = self.contents() as *const T; + assert!(!contents_ptr.is_null()); + unsafe { std::slice::from_raw_parts(contents_ptr, len) } + } + + pub fn read_to_vec(&self, len: usize) -> Vec { + self.read_to_slice(len).to_vec() + } } diff --git a/src/mps.rs b/src/mps.rs index b1f70f1..86fb696 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -8,7 +8,7 @@ use super::*; use half::{bf16, f16}; use objc::runtime::{BOOL, YES}; -use std::fmt::Debug; +use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; #[cfg_attr( @@ -324,7 +324,7 @@ impl PolygonAccelerationStructureRef { } pub fn set_index_type(&self, _data_type: T) { - unsafe { msg_send![self, setIndexType: T::CODE] } + unsafe { msg_send![self, setIndexType: T::TYPE_ID] } } pub fn set_mask_buffer(&self, buffer: Option<&BufferRef>) { @@ -554,10 +554,10 @@ pub struct MPSIntersectionDistancePrimitiveIndexCoordinates { /// See . pub trait MPSDataType: Clone + Copy + PartialEq + Eq + Debug + Hash { type Type: Default + Clone + Copy + PartialEq + Debug + Sized; - const CODE: u32; + const TYPE_ID: NSUInteger; /// See . - const SIZE: u32 = ((Self::CODE & 0xFFFF) >> 3); + const SIZE: NSUInteger = ((Self::TYPE_ID & 0xFFFF) >> 3) as NSUInteger; fn from_f64(v: f64) -> Self::Type; @@ -565,24 +565,24 @@ pub trait MPSDataType: Clone + Copy + PartialEq + Eq + Debug + Hash { } /// A common bit for all floating point data types. Zero for integer types -const MPS_FLOATBIT_ENCODING: u32 = 0x10000000; +const MPS_FLOATBIT_ENCODING: NSUInteger = 0x10000000; /// A common bit for all complex point data types. Zero for integer types -const MPS_COMPLEXBIT_ENCODING: u32 = MPS_FLOATBIT_ENCODING | 0x01000000; +const MPS_COMPLEXBIT_ENCODING: NSUInteger = MPS_FLOATBIT_ENCODING | 0x01000000; /// A common bit for all signed data types -const MPS_SIGNEDBIT_ENCODING: u32 = 0x20000000; +const MPS_SIGNEDBIT_ENCODING: NSUInteger = 0x20000000; /// A common bit for all alternate encoding data types -const MPS_ALTERNATE_ENCODING: u32 = 0x80000000; +const MPS_ALTERNATE_ENCODING: NSUInteger = 0x80000000; /// A common bit for all normalized data types. /// If set, the value of the shall be interpreted as value / UNORM_TYPE_MAX /// Normalized values have range [0, 1.0] if unsigned and [-1,1] if signed. /// SNORM_TYPE_MIN is interpreted as SNORM_TYPE_MIN+1 per standard Metal rules. -const MPS_NORMALIZEDBIT_ENCODING: u32 = 0x40000000; +const MPS_NORMALIZEDBIT_ENCODING: NSUInteger = 0x40000000; macro_rules! mps_datatype_impl { - ($dt:ident, $dt_ty:ty, $code:expr, $from_f64:expr, $to_f64:expr) => { + ($dt:ident, $dt_ty:ty, $type_id:expr, $from_f64:expr, $to_f64:expr) => { impl MPSDataType for $dt { type Type = $dt_ty; - const CODE: u32 = $code; + const TYPE_ID: NSUInteger = $type_id; fn from_f64(v: f64) -> Self::Type { $from_f64(v) @@ -595,18 +595,18 @@ macro_rules! mps_datatype_impl { }; } macro_rules! mps_datatype { - ($dt:ident, $dt_ty:ty, $code:expr, $from_f64:expr, $to_f64:expr, $comment:expr) => { + ($dt:ident, $dt_ty:ty, $type_id:expr, $from_f64:expr, $to_f64:expr, $comment:expr) => { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] #[doc=$comment] pub struct $dt; - mps_datatype_impl!($dt, $dt_ty, $code, $from_f64, $to_f64); + mps_datatype_impl!($dt, $dt_ty, $type_id, $from_f64, $to_f64); }; - ($dt:ident, $dt_ty:ty, $code:expr, $from_f64:expr, $to_f64:expr) => { + ($dt:ident, $dt_ty:ty, $type_id:expr, $from_f64:expr, $to_f64:expr) => { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct $dt; - mps_datatype_impl!($dt, $dt_ty, $code, $from_f64, $to_f64); + mps_datatype_impl!($dt, $dt_ty, $type_id, $from_f64, $to_f64); }; } mps_datatype!(Invalid, (), 0, |_: f64| (), |_: ()| 0.0); @@ -773,54 +773,54 @@ mps_datatype!( pub trait Valid {} /// Helper struct used to indicate a valid matrix multiplication input type. -pub struct MatMulInput { +pub struct GEMMInput { _marker: PhantomData, } /// Input data type must be one of MPSDataTypeFloat32, MPSDataTypeFloat16, MPSDataTypeInt8, /// or MPSDataTypeInt16 -impl Valid for MatMulInput {} -impl Valid for MatMulInput {} -impl Valid for MatMulInput {} -impl Valid for MatMulInput {} +impl Valid for GEMMInput {} +impl Valid for GEMMInput {} +impl Valid for GEMMInput {} +impl Valid for GEMMInput {} /// Helper struct used to indicate a valid matrix multiplication result type. -pub struct MatMulResult { +pub struct GEMMResult { _marker: PhantomData, } /// Only MPSDataTypeFloat16 and MPSDataTypeFloat32 are supported for the result matrix. -impl Valid for MatMulResult {} -impl Valid for MatMulResult {} +impl Valid for GEMMResult {} +impl Valid for GEMMResult {} /// Helper struct used to indicate valid matrix multiplication types. -pub struct MatMulSpecification +pub struct GEMMSpecification where A: MPSDataType, B: MPSDataType, C: MPSDataType, - MatMulInput: Valid, - MatMulInput: Valid, - MatMulResult: Valid, + GEMMInput: Valid, + GEMMInput: Valid, + GEMMResult: Valid, { _marker: PhantomData<(A, B, C)>, } /// Mixed input matrix multiplication is only for -impl Valid for MatMulSpecification {} +impl Valid for GEMMSpecification {} /// All valid input types can produce a MPSDataTypeFloat32 result. -impl Valid for MatMulSpecification +impl Valid for GEMMSpecification where T: MPSDataType, - MatMulInput: Valid, + GEMMInput: Valid, { } /// These input types can produce a MPSDataTypeFloat16 result. -impl Valid for MatMulSpecification {} -impl Valid for MatMulSpecification {} -impl Valid for MatMulSpecification {} +impl Valid for GEMMSpecification {} +impl Valid for GEMMSpecification {} +impl Valid for GEMMSpecification {} /// See pub enum MPSMatrixDescriptor {} @@ -836,7 +836,7 @@ impl MatrixDescriptor { rows: NSUInteger, columns: NSUInteger, row_bytes: NSUInteger, - data_type: u32, + data_type: NSUInteger, ) -> Self { unsafe { msg_send![ @@ -870,7 +870,7 @@ impl MatrixDescriptor { } } - fn row_bytes_for_columns(columns: NSUInteger, data_type: u32) -> NSUInteger { + fn row_bytes_for_columns(columns: NSUInteger, data_type: NSUInteger) -> NSUInteger { unsafe { msg_send![ class!(MPSMatrixDescriptor), @@ -883,7 +883,7 @@ impl MatrixDescriptor { impl From<&Matrix> for MatrixDescriptor { fn from(matrix: &Matrix) -> Self { - let data_type = T::CODE; + let data_type = T::TYPE_ID; // The number of bytes between starting elements of consecutive rows. let row_bytes = MatrixDescriptor::row_bytes_for_columns(matrix.columns, data_type); Self::init_single(matrix.rows, matrix.columns, row_bytes, data_type) @@ -902,9 +902,55 @@ foreign_obj_type! { /// Generic matrix for MPSDataTypes. #[derive(Debug)] pub struct Matrix { - pub entries: Vec, // row-major order - pub rows: NSUInteger, - pub columns: NSUInteger, + entries: Vec, // row-major order + rows: NSUInteger, + columns: NSUInteger, +} + +impl Matrix { + pub fn new(entries: Vec, rows: NSUInteger, columns: NSUInteger) -> Self { + assert_eq!(entries.len(), rows as usize * columns as usize); + Self { + entries, + rows, + columns, + } + } + pub fn entries(&self) -> Vec { + self.entries.clone() + } +} + +impl From> for Matrix { + fn from(buffer: MatrixBuffer) -> Self { + Self::new(buffer.contents(), buffer.rows, buffer.columns) + } +} + +impl Display for Matrix { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + assert_eq!( + self.entries.len(), + self.rows as usize * self.columns as usize + ); + let mut col = 0; + for i in 0..(self.rows * self.columns) as usize { + if col == 0 { + write!(f, "|")?; + } + + write!(f, "{:?}", self.entries.get(i).ok_or(std::fmt::Error)?)?; + + if col < self.columns as usize - 1 { + write!(f, ", ")?; + col += 1; + } else { + writeln!(f, "|")?; + col = 0; + } + } + Ok(()) + } } impl MatrixObject { @@ -1092,9 +1138,8 @@ impl MatrixMultiplicationRef { } } -#[derive(Debug)] pub struct MatrixBuffer { - buffer: Buffer, + pub buffer: Buffer, rows: NSUInteger, columns: NSUInteger, _marker: PhantomData, @@ -1117,37 +1162,16 @@ impl MatrixBuffer { } } - pub fn new_with_data( - device: &DeviceRef, - matrix: &Matrix, - length: NSUInteger, - options: MTLResourceOptions, - ) -> Self { - let buffer = device.new_buffer_with_data(matrix.entries.as_ptr().cast(), length, options); - MatrixBuffer { - buffer, - rows: matrix.rows, - columns: matrix.columns, - _marker: PhantomData, - } + pub fn count(&self) -> usize { + (self.rows * self.columns) as usize } pub fn contents(&self) -> Vec { - let contents = self.buffer.contents() as *const T::Type; - let sl: &[T::Type] = - unsafe { std::slice::from_raw_parts(contents, (self.rows * self.columns) as usize) }; - sl.to_vec() - } - pub fn read_to_vec(&self) -> Vec { - read_buffer_to_vec(&self.buffer, (self.rows * self.columns) as usize) + self.buffer.read_to_vec(self.count()) } } -pub fn read_buffer_to_vec(buffer: &BufferRef, len: usize) -> Vec { - Vec::from(unsafe { std::slice::from_raw_parts(buffer.contents() as *const T, len) }) -} - -pub fn apply_gemm( +pub fn encode_gemm( device: &DeviceRef, command_buffer: &CommandBufferRef, transpose_left: bool, @@ -1161,10 +1185,10 @@ where A: MPSDataType, B: MPSDataType, C: MPSDataType, - MatMulInput: Valid, - MatMulInput: Valid, - MatMulResult: Valid, - MatMulSpecification: Valid, + GEMMInput: Valid, + GEMMInput: Valid, + GEMMResult: Valid, + GEMMSpecification: Valid, { let M = if transpose_left { left.columns @@ -1182,16 +1206,12 @@ where left.columns }; - validate_matrix_multiplication(left, right, M, N, K); + validate_shapes(M, N, K); // Create descriptors for the matrices. - let left_row_bytes = MatrixDescriptor::row_bytes_for_columns(K, A::CODE); - let right_row_bytes = MatrixDescriptor::row_bytes_for_columns(N, B::CODE); - let result_row_bytes = MatrixDescriptor::row_bytes_for_columns(N, C::CODE); - - let left_descriptor = MatrixDescriptor::init_single(M, K, left_row_bytes, A::CODE); - let right_descriptor = MatrixDescriptor::init_single(K, N, right_row_bytes, B::CODE); - let result_descriptor = MatrixDescriptor::init_single(M, N, result_row_bytes, C::CODE); + let left_row_bytes = MatrixDescriptor::row_bytes_for_columns(K, A::TYPE_ID); + let right_row_bytes = MatrixDescriptor::row_bytes_for_columns(N, B::TYPE_ID); + let result_row_bytes = MatrixDescriptor::row_bytes_for_columns(N, C::TYPE_ID); // Create buffers let options = MTLResourceOptions::StorageModeShared; @@ -1202,9 +1222,15 @@ where let result_buffer = MatrixBuffer::new(device, M, N, M * result_row_bytes, options); + // Create descriptors + let left_descriptor = MatrixDescriptor::init_single(M, K, K * A::SIZE, A::TYPE_ID); + let right_descriptor = MatrixDescriptor::init_single(K, N, N * B::SIZE, B::TYPE_ID); + let result_descriptor = MatrixDescriptor::init_single(M, N, N * C::SIZE, C::TYPE_ID); + // Create matrix objects let left_matrix = MatrixObject::init_with_buffer_descriptor(&left_buffer, &left_descriptor).unwrap(); + let right_matrix = MatrixObject::init_with_buffer_descriptor(&right_buffer, &right_descriptor).unwrap(); let result_matrix = @@ -1236,50 +1262,18 @@ where result_buffer } -fn validate_matrix_multiplication( - left: &Matrix, - right: &Matrix, - M: NSUInteger, - N: NSUInteger, - K: NSUInteger, -) where - A: MPSDataType, - B: MPSDataType, - C: MPSDataType, - MatMulInput: Valid, - MatMulInput: Valid, - MatMulResult: Valid, - MatMulSpecification: Valid, -{ - - // TODO ... - - // For matrix multiplication, the number of columns in the first matrix must be equal to - // the number of rows in the second matrix. - // The result matrix has the number of rows of the first and the number of columns of the - // second matrix. - // If only one matrix is transposed then the result matrix has the number of rows of the - // transposed matrix and the number of columns of the non-transposed matrix. - +fn validate_shapes(M: NSUInteger, N: NSUInteger, K: NSUInteger) { // Certain constraints apply to the sizes of the matrices depending on the transposition // operations and sizes requested at initialization time as well as the origins at the time // this routine is called: - // - // The left input matrix must be large enough to hold an array of size resultRows x interiorColumns - // elements beginning at leftMatrixOrigin. - // assert!(left_matrix.rows * left_matrix.columns >= self.result_rows * self.interior_columns); - // The right input matrix must be large enough to hold an array of size - // interiorColumns x resultColumns elements beginning at rightMatrixOrigin. - // assert!( - // right_matrix.rows * right_matrix.columns >= self.interior_columns * self.result_columns - // ); - // The result matrix must be large enough to hold an array of size resultRows x resultColumns - // elements beginning at resultMatrixOrigin. - // assert!( - // result_matrix.rows * result_matrix.columns >= self.result_rows * self.result_columns - // ); - - // Each matrix within the range specified by batchStart and batchSize, which also specifies - // a valid set of matrices within leftMatrix, rightMatrix, and resultMatrix, will - // be processed. + assert!(M > 0); + assert!(N > 0); + assert!(K > 0); + // Left column size must equal right row size. + assert_eq!(K, N); + + // The left matrix must be larger or equal to result rows * interior columns + assert!(M * K >= M * N); + // The right matrix must be larger or equal to result columns * interior columns + assert!(K * N >= M * N); } From aed9b290ff75e9c7c825be9ffedab4c93fb769d1 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Thu, 26 Oct 2023 13:59:08 +0200 Subject: [PATCH 05/12] Add gemm correctness test and improve the performance test --- examples/mps/matrix-multiplication/main.rs | 227 +++++++++++++++++---- src/mps.rs | 25 ++- 2 files changed, 205 insertions(+), 47 deletions(-) diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs index d833b65..c4792ec 100644 --- a/examples/mps/matrix-multiplication/main.rs +++ b/examples/mps/matrix-multiplication/main.rs @@ -1,6 +1,126 @@ use metal::mps::*; use metal::*; use rand::{thread_rng, Rng}; +use std::io::Write; +use std::ops::{AddAssign, Mul}; +use std::{array, io}; + +fn main() { + correctness(); + performance(); +} + +fn correctness() { + // First verify the correctness of the naive solution + let a = Matrix::new([1, 2, 6, 24, 120, 720], 3, 2); + let b = Matrix::new([1, 2, 3, 5, 8, 13], 2, 3); + let result = matrix_mul::(a, b); + assert_eq!( + result.entries(), + &[11, 18, 29, 126, 204, 330, 3720, 6000, 9720] + ); + + const M: u64 = 100; + const N: u64 = 100; + const K: u64 = 100; + const ITERATIONS: usize = 50; + + let device = Device::system_default().expect("No device found"); + let command_queue = device.new_command_queue(); + + println!("Correctness: "); + for i in 0..ITERATIONS { + progress_bar(i, ITERATIONS); + + let left = generate_matrix::(); + let right = generate_matrix::(); + + let command_buffer = command_queue.new_command_buffer(); + let result = encode_gemm( + &device, + command_buffer, + false, + false, + &left, + &right, + 1.0, + 0.0, + ); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let expected = matrix_mul(left, right); + approx_eq(result.contents(), expected.entries().to_vec()); + } + + println!(" ✅\n"); +} + +fn performance() { + const M: u64 = 4096; + const N: u64 = 4096; + const K: u64 = 4096; + + const ITERATIONS: usize = 50; + + println!("Performance: "); + println!("Generating input matrices: (f32 {M}x{K} and f16 {K}x{N})"); + // Generate random matrices + let left = generate_matrix::(); + let right = generate_matrix::(); + + // Setup + let device = Device::system_default().expect("No device found"); + let command_queue = device.new_command_queue(); + + let cases = [ + (false, false, 1.0, 0.0), + (true, false, 1.0, 0.0), + (false, true, 1.0, 0.0), + (false, false, 0.5, 0.0), + (false, false, 1.0, 0.5), + ]; + for (t_left, t_right, alpha, beta) in cases { + println!("Running with transpose left: {t_left}, transpose right: {t_right}, alpha: {alpha}, beta: {beta}"); + let mut flops: Vec = vec![]; + + let mut total_time = std::time::Duration::new(0, 0); + for i in 0..ITERATIONS { + progress_bar(i, ITERATIONS); + + let start = std::time::Instant::now(); + let command_buffer = command_queue.new_command_buffer(); + let _ = encode_gemm( + &device, + command_buffer, + t_left, + t_right, + &left, + &right, + alpha, + beta, + ); + command_buffer.commit(); + command_buffer.wait_until_completed(); + + let time = std::time::Instant::now() - start; + + total_time += time; + + // Calculate GFLOPS + // C <- alpha * AB + beta * C + // Operations = 2(M * N * K) + flops.push((M * N * (2 * K + 2)) as f64 / (time.as_secs_f64() * 1e+9f64)); + } + println!(" ✅"); + + let avg_gflops = flops.iter().sum::() / flops.len() as f64; + println!("Avg GFLOPS: {}", avg_gflops); + println!("Total time: {:#?}", total_time); + println!("Avg time: {:#?}", total_time / ITERATIONS as u32); + println!() + } +} fn generate_matrix() -> Matrix where @@ -9,59 +129,78 @@ where { let mut rng = thread_rng(); Matrix::new( - (0..ROWS * COLS).map(|_| T::from_f64(rng.gen())).collect(), + (0..ROWS * COLS).map(|_| T::from_f64(rng.gen())), ROWS as NSUInteger, COLS as NSUInteger, ) } -fn main() { - const M: u64 = 4096; - const N: u64 = 4096; - const K: u64 = 4096; - const RUNS: u64 = 100; +// Naive matrix multiplication for testing +fn matrix_mul(a: Matrix, b: Matrix) -> Matrix +where + T::Type: AddAssign + Mul + Copy, +{ + assert_eq!(a.columns(), b.rows()); + let sum_count = a.columns() as usize; + let rows = a.rows() as usize; + let columns = b.columns() as usize; + let size = rows * columns; - let transpose_left = false; - let transpose_right = false; - let alpha = 1.0; - let beta = 0.0; + let mut entries = Vec::with_capacity(size); - // Generate random matrices - let left = generate_matrix::(); - let right = generate_matrix::(); + for idx in 0..size { + let i = idx / rows; + let j = idx % columns; - // Setup - let device = Device::system_default().expect("No device found"); - let command_queue = device.new_command_queue(); - let mut total_time = std::time::Duration::new(0, 0); + let mut sum = T::from_f64(0.0); + for di in 0..sum_count { + sum += a.entry(i, di) * b.entry(di, j); + } + entries.push(sum); + } - for _ in 0..RUNS { - let command_buffer = command_queue.new_command_buffer(); - let start = std::time::Instant::now(); - let _ = encode_gemm( - &device, - command_buffer, - transpose_left, - transpose_right, - &left, - &right, - alpha, - beta, - ); - command_buffer.commit(); - command_buffer.wait_until_completed(); - let time = std::time::Instant::now() - start; - total_time += time; + Matrix::new(entries, a.rows(), b.columns()) +} + +fn euclidean_distance(a: Vec, b: Vec) -> f64 +where + T: Into + Clone + Copy, +{ + assert_eq!(a.len(), b.len(), "Lengths not equal"); + + let mut sum = 0.0; + + for i in 0..a.len() { + sum += (a[i].into() - b[i].into()).powi(2); } - // Calculate GFLOPS - // C <- alpha * AB + beta * C - // Operations = M * N * (K+2) + M * N * K - let ops_count = M * N * (2 * K + 2); - let ops_count = (ops_count * RUNS) as f64; - let gflops = ops_count / (total_time.as_secs_f64() * 1000e+3f64); - // TODO: Something is wrong here hehe - println!("GFLOPS: {}", gflops); - println!("Total time: {:?}", total_time); - println!("Avg time: {:?}", total_time / RUNS as u32); + sum.sqrt() +} + +fn approx_eq(a: Vec, b: Vec) +where + T: Into + Clone + Copy, +{ + assert_eq!(a.len(), b.len(), "Lengths not equal"); + + let avg_magnitude = 0.004f64; + let avg_deviation = (a.len() as f64).sqrt(); + let tolerance = avg_magnitude.max(avg_deviation * 3e-7); + + let distance = euclidean_distance(a, b); + assert!( + distance < tolerance, + "Distance not less than tolerance: {} < {} ", + distance, + tolerance + ); +} + +fn progress_bar(i: usize, len: usize) { + print!("\r"); + print!("["); + print!("{}", "=".repeat(i)); + print!("{}", " ".repeat(len - i - 1)); + print!("]"); + io::stdout().flush().unwrap(); } diff --git a/src/mps.rs b/src/mps.rs index 86fb696..b984284 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -908,7 +908,12 @@ pub struct Matrix { } impl Matrix { - pub fn new(entries: Vec, rows: NSUInteger, columns: NSUInteger) -> Self { + pub fn new>( + entries: E, + rows: NSUInteger, + columns: NSUInteger, + ) -> Matrix { + let entries: Vec = entries.into_iter().collect(); assert_eq!(entries.len(), rows as usize * columns as usize); Self { entries, @@ -916,8 +921,22 @@ impl Matrix { columns, } } - pub fn entries(&self) -> Vec { - self.entries.clone() + pub fn entries(&self) -> &[T::Type] { + &self.entries + } + + pub fn entry(&self, row: usize, column: usize) -> T::Type { + assert!(row < self.rows as usize); + assert!(column < self.columns as usize); + self.entries[row * self.columns as usize + column] + } + + pub fn rows(&self) -> NSUInteger { + self.rows + } + + pub fn columns(&self) -> NSUInteger { + self.columns } } From 2c1490c169297e2282fb6136144ba8d68bcd865a Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Mon, 30 Oct 2023 10:33:18 +0100 Subject: [PATCH 06/12] small improvements --- examples/mps/matrix-multiplication/main.rs | 10 +++-- src/mps.rs | 50 +++++++++++++--------- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs index c4792ec..07e839a 100644 --- a/examples/mps/matrix-multiplication/main.rs +++ b/examples/mps/matrix-multiplication/main.rs @@ -1,9 +1,11 @@ -use metal::mps::*; -use metal::*; -use rand::{thread_rng, Rng}; +use std::io; use std::io::Write; use std::ops::{AddAssign, Mul}; -use std::{array, io}; + +use rand::{thread_rng, Rng}; + +use metal::mps::*; +use metal::*; fn main() { correctness(); diff --git a/src/mps.rs b/src/mps.rs index b984284..672bc1f 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -5,12 +5,14 @@ // http://opensource.org/licenses/MIT>, at your option. This file may not be // copied, modified, or distributed except according to those terms. -use super::*; -use half::{bf16, f16}; -use objc::runtime::{BOOL, YES}; use std::fmt::{Debug, Display, Formatter}; use std::hash::Hash; +use half::{bf16, f16}; +use objc::runtime::{BOOL, YES}; + +use super::*; + #[cfg_attr( feature = "link", link(name = "MetalPerformanceShaders", kind = "framework") @@ -780,8 +782,11 @@ pub struct GEMMInput { /// Input data type must be one of MPSDataTypeFloat32, MPSDataTypeFloat16, MPSDataTypeInt8, /// or MPSDataTypeInt16 impl Valid for GEMMInput {} + impl Valid for GEMMInput {} + impl Valid for GEMMInput {} + impl Valid for GEMMInput {} /// Helper struct used to indicate a valid matrix multiplication result type. @@ -791,6 +796,7 @@ pub struct GEMMResult { /// Only MPSDataTypeFloat16 and MPSDataTypeFloat32 are supported for the result matrix. impl Valid for GEMMResult {} + impl Valid for GEMMResult {} /// Helper struct used to indicate valid matrix multiplication types. @@ -819,7 +825,9 @@ where /// These input types can produce a MPSDataTypeFloat16 result. impl Valid for GEMMSpecification {} + impl Valid for GEMMSpecification {} + impl Valid for GEMMSpecification {} /// See @@ -902,7 +910,8 @@ foreign_obj_type! { /// Generic matrix for MPSDataTypes. #[derive(Debug)] pub struct Matrix { - entries: Vec, // row-major order + entries: Vec, + // row-major order rows: NSUInteger, columns: NSUInteger, } @@ -1158,9 +1167,11 @@ impl MatrixMultiplicationRef { } pub struct MatrixBuffer { - pub buffer: Buffer, + buffer: Buffer, rows: NSUInteger, columns: NSUInteger, + count: usize, + allocated_size: usize, _marker: PhantomData, } @@ -1177,16 +1188,18 @@ impl MatrixBuffer { buffer, rows, columns, + count: (rows * columns) as usize, + allocated_size: length as usize, _marker: PhantomData, } } pub fn count(&self) -> usize { - (self.rows * self.columns) as usize + self.count } pub fn contents(&self) -> Vec { - self.buffer.read_to_vec(self.count()) + self.buffer.read_to_vec(self.count) } } @@ -1209,23 +1222,18 @@ where GEMMResult: Valid, GEMMSpecification: Valid, { - let M = if transpose_left { - left.columns - } else { - left.rows - }; - let N = if transpose_right { - right.rows + let (M, K) = if transpose_left { + (left.columns, left.rows) } else { - right.columns + (left.rows, left.columns) }; - let K = if transpose_left { - left.rows + let (N, B_K) = if transpose_right { + (right.rows, right.columns) } else { - left.columns + (right.columns, right.rows) }; - validate_shapes(M, N, K); + validate_shapes(M, N, K, B_K); // Create descriptors for the matrices. let left_row_bytes = MatrixDescriptor::row_bytes_for_columns(K, A::TYPE_ID); @@ -1249,7 +1257,6 @@ where // Create matrix objects let left_matrix = MatrixObject::init_with_buffer_descriptor(&left_buffer, &left_descriptor).unwrap(); - let right_matrix = MatrixObject::init_with_buffer_descriptor(&right_buffer, &right_descriptor).unwrap(); let result_matrix = @@ -1281,13 +1288,14 @@ where result_buffer } -fn validate_shapes(M: NSUInteger, N: NSUInteger, K: NSUInteger) { +fn validate_shapes(M: NSUInteger, N: NSUInteger, K: NSUInteger, B_K: NSUInteger) { // Certain constraints apply to the sizes of the matrices depending on the transposition // operations and sizes requested at initialization time as well as the origins at the time // this routine is called: assert!(M > 0); assert!(N > 0); assert!(K > 0); + assert_eq!(K, B_K); // Left column size must equal right row size. assert_eq!(K, N); From 779131260916a9ac98758efbc47df6ca6f2acbab Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 31 Oct 2023 12:54:24 +0100 Subject: [PATCH 07/12] wait_until_completed() outside of inner loop --- examples/mps/matrix-multiplication/main.rs | 28 ++++++++++------------ 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs index 07e839a..54340c2 100644 --- a/examples/mps/matrix-multiplication/main.rs +++ b/examples/mps/matrix-multiplication/main.rs @@ -73,7 +73,6 @@ fn performance() { // Setup let device = Device::system_default().expect("No device found"); - let command_queue = device.new_command_queue(); let cases = [ (false, false, 1.0, 0.0), @@ -84,14 +83,13 @@ fn performance() { ]; for (t_left, t_right, alpha, beta) in cases { println!("Running with transpose left: {t_left}, transpose right: {t_right}, alpha: {alpha}, beta: {beta}"); - let mut flops: Vec = vec![]; + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); - let mut total_time = std::time::Duration::new(0, 0); + let start = std::time::Instant::now(); for i in 0..ITERATIONS { progress_bar(i, ITERATIONS); - let start = std::time::Instant::now(); - let command_buffer = command_queue.new_command_buffer(); let _ = encode_gemm( &device, command_buffer, @@ -102,24 +100,22 @@ fn performance() { alpha, beta, ); - command_buffer.commit(); - command_buffer.wait_until_completed(); + } + command_buffer.commit(); + command_buffer.wait_until_completed(); - let time = std::time::Instant::now() - start; + let total_time = start.elapsed(); - total_time += time; + // Calculate GFLOPS + // C <- alpha * AB + beta * C + // Operations = 2(M * N * K) + let avg_gflops = (ITERATIONS as u64 * (M * N * (2 * K - 1))) as f64 + / (total_time.as_secs_f64() * 1e+9f64); - // Calculate GFLOPS - // C <- alpha * AB + beta * C - // Operations = 2(M * N * K) - flops.push((M * N * (2 * K + 2)) as f64 / (time.as_secs_f64() * 1e+9f64)); - } println!(" ✅"); - let avg_gflops = flops.iter().sum::() / flops.len() as f64; println!("Avg GFLOPS: {}", avg_gflops); println!("Total time: {:#?}", total_time); - println!("Avg time: {:#?}", total_time / ITERATIONS as u32); println!() } } From 933b3b8d81ec90e3ab28378e74399bda0e38768b Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 31 Oct 2023 13:48:06 +0100 Subject: [PATCH 08/12] Easier to change matrix types in gemm benchmark --- examples/mps/matrix-multiplication/main.rs | 23 ++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs index 54340c2..2606269 100644 --- a/examples/mps/matrix-multiplication/main.rs +++ b/examples/mps/matrix-multiplication/main.rs @@ -1,3 +1,4 @@ +use std::any::type_name; use std::io; use std::io::Write; use std::ops::{AddAssign, Mul}; @@ -58,18 +59,32 @@ fn correctness() { println!(" ✅\n"); } +fn short_type_name() -> String { + let name = type_name::(); + let mut parts = name.split("::"); + parts.last().unwrap().to_string() +} + fn performance() { const M: u64 = 4096; const N: u64 = 4096; const K: u64 = 4096; + type A = Float32; + type B = Float16; + type C = Float32; const ITERATIONS: usize = 50; println!("Performance: "); - println!("Generating input matrices: (f32 {M}x{K} and f16 {K}x{N})"); + + let a = short_type_name::(); + let b = short_type_name::(); + let c = short_type_name::(); + println!("{M}x{K}x{a} * {K}x{N}x{b} = {M}x{N}x{c}"); + println!("Generating input matrices..."); // Generate random matrices - let left = generate_matrix::(); - let right = generate_matrix::(); + let left = generate_matrix::(); + let right = generate_matrix::(); // Setup let device = Device::system_default().expect("No device found"); @@ -90,7 +105,7 @@ fn performance() { for i in 0..ITERATIONS { progress_bar(i, ITERATIONS); - let _ = encode_gemm( + let _: MatrixBuffer = encode_gemm( &device, command_buffer, t_left, From 36ae08043549ba0bbf94be447123f0a05bb56480 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 31 Oct 2023 19:13:06 +0100 Subject: [PATCH 09/12] Ditch generic Matrix and use MatrixBuffer instead. Create buffers separately of encoding. --- examples/mps/matrix-multiplication/main.rs | 115 +++++++++-------- src/mps.rs | 139 ++++++--------------- 2 files changed, 99 insertions(+), 155 deletions(-) diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs index 2606269..d362010 100644 --- a/examples/mps/matrix-multiplication/main.rs +++ b/examples/mps/matrix-multiplication/main.rs @@ -15,12 +15,15 @@ fn main() { fn correctness() { // First verify the correctness of the naive solution - let a = Matrix::new([1, 2, 6, 24, 120, 720], 3, 2); - let b = Matrix::new([1, 2, 3, 5, 8, 13], 2, 3); - let result = matrix_mul::(a, b); + let m = 3; + let n = 3; + let k = 2; + let a = vec![1, 2, 6, 24, 120, 720]; + let b = vec![1, 2, 6, 24, 120, 720]; + let result = matrix_mul::(a, b, m, n, k); assert_eq!( - result.entries(), - &[11, 18, 29, 126, 204, 330, 3720, 6000, 9720] + result, + &[49, 242, 1446, 582, 2892, 17316, 17400, 86640, 519120] ); const M: u64 = 100; @@ -35,25 +38,23 @@ fn correctness() { for i in 0..ITERATIONS { progress_bar(i, ITERATIONS); - let left = generate_matrix::(); - let right = generate_matrix::(); + let a = generate_matrix::(&device); + let b = generate_matrix::(&device); + let c = generate_matrix::(&device); let command_buffer = command_queue.new_command_buffer(); - let result = encode_gemm( - &device, - command_buffer, - false, - false, - &left, - &right, - 1.0, - 0.0, - ); + encode_gemm(&device, command_buffer, false, false, &a, &b, &c, 1.0, 0.0); command_buffer.commit(); command_buffer.wait_until_completed(); - let expected = matrix_mul(left, right); - approx_eq(result.contents(), expected.entries().to_vec()); + let expected = matrix_mul::( + a.contents(), + b.contents(), + M as usize, + K as usize, + N as usize, + ); + approx_eq(c.contents(), expected); } println!(" ✅\n"); @@ -61,7 +62,7 @@ fn correctness() { fn short_type_name() -> String { let name = type_name::(); - let mut parts = name.split("::"); + let parts = name.split("::"); parts.last().unwrap().to_string() } @@ -77,18 +78,19 @@ fn performance() { println!("Performance: "); - let a = short_type_name::(); - let b = short_type_name::(); - let c = short_type_name::(); - println!("{M}x{K}x{a} * {K}x{N}x{b} = {M}x{N}x{c}"); - println!("Generating input matrices..."); - // Generate random matrices - let left = generate_matrix::(); - let right = generate_matrix::(); + let a_tname = short_type_name::(); + let b_tname = short_type_name::(); + let c_tname = short_type_name::(); + println!("{M}x{K}x{a_tname} * {K}x{N}x{b_tname} = {M}x{N}x{c_tname}"); - // Setup let device = Device::system_default().expect("No device found"); + println!("Generating input matrices..."); + // Generate random matrices + let a = generate_matrix::(&device); + let b = generate_matrix::(&device); + let c = generate_matrix::(&device); + let cases = [ (false, false, 1.0, 0.0), (true, false, 1.0, 0.0), @@ -105,13 +107,14 @@ fn performance() { for i in 0..ITERATIONS { progress_bar(i, ITERATIONS); - let _: MatrixBuffer = encode_gemm( + encode_gemm( &device, command_buffer, t_left, t_right, - &left, - &right, + &a, + &b, + &c, alpha, beta, ); @@ -135,44 +138,54 @@ fn performance() { } } -fn generate_matrix() -> Matrix +fn generate_matrix(device: &Device) -> MatrixBuffer where T: MPSDataType, GEMMInput: Valid, { let mut rng = thread_rng(); - Matrix::new( - (0..ROWS * COLS).map(|_| T::from_f64(rng.gen())), - ROWS as NSUInteger, - COLS as NSUInteger, - ) + + // Create descriptors for the matrices. + let row_bytes_for_columns = MatrixDescriptor::row_bytes_for_columns(COLS, T::TYPE_ID); + + // Create buffers + let options = MTLResourceOptions::StorageModeShared; + let data = (0..ROWS * COLS) + .map(|_| T::from_f64(rng.gen())) + .collect::>(); + let buffer = + device.new_buffer_with_data(data.as_ptr().cast(), ROWS * row_bytes_for_columns, options); + + MatrixBuffer::from_buffer(buffer, ROWS, COLS) } // Naive matrix multiplication for testing -fn matrix_mul(a: Matrix, b: Matrix) -> Matrix +fn matrix_mul( + a: Vec, + b: Vec, + m: usize, + n: usize, + k: usize, +) -> Vec where T::Type: AddAssign + Mul + Copy, { - assert_eq!(a.columns(), b.rows()); - let sum_count = a.columns() as usize; - let rows = a.rows() as usize; - let columns = b.columns() as usize; - let size = rows * columns; + let size = m * n; - let mut entries = Vec::with_capacity(size); + let mut c = Vec::with_capacity(size); for idx in 0..size { - let i = idx / rows; - let j = idx % columns; + let i = idx / m; + let j = idx % n; let mut sum = T::from_f64(0.0); - for di in 0..sum_count { - sum += a.entry(i, di) * b.entry(di, j); + for di in 0..k { + sum += a[(i * k) + di] * b[(di * n) + j]; } - entries.push(sum); + c.push(sum); } - Matrix::new(entries, a.rows(), b.columns()) + c } fn euclidean_distance(a: Vec, b: Vec) -> f64 diff --git a/src/mps.rs b/src/mps.rs index 672bc1f..0bd6fb0 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -878,7 +878,7 @@ impl MatrixDescriptor { } } - fn row_bytes_for_columns(columns: NSUInteger, data_type: NSUInteger) -> NSUInteger { + pub fn row_bytes_for_columns(columns: NSUInteger, data_type: NSUInteger) -> NSUInteger { unsafe { msg_send![ class!(MPSMatrixDescriptor), @@ -889,85 +889,26 @@ impl MatrixDescriptor { } } -impl From<&Matrix> for MatrixDescriptor { - fn from(matrix: &Matrix) -> Self { - let data_type = T::TYPE_ID; - // The number of bytes between starting elements of consecutive rows. - let row_bytes = MatrixDescriptor::row_bytes_for_columns(matrix.columns, data_type); - Self::init_single(matrix.rows, matrix.columns, row_bytes, data_type) - } -} - /// See pub enum MPSMatrix {} foreign_obj_type! { type CType = MPSMatrix; - pub struct MatrixObject; + pub struct Matrix; type ParentType = NsObject; } -/// Generic matrix for MPSDataTypes. -#[derive(Debug)] -pub struct Matrix { - entries: Vec, - // row-major order - rows: NSUInteger, - columns: NSUInteger, -} - -impl Matrix { - pub fn new>( - entries: E, - rows: NSUInteger, - columns: NSUInteger, - ) -> Matrix { - let entries: Vec = entries.into_iter().collect(); - assert_eq!(entries.len(), rows as usize * columns as usize); - Self { - entries, - rows, - columns, - } - } - pub fn entries(&self) -> &[T::Type] { - &self.entries - } - - pub fn entry(&self, row: usize, column: usize) -> T::Type { - assert!(row < self.rows as usize); - assert!(column < self.columns as usize); - self.entries[row * self.columns as usize + column] - } - - pub fn rows(&self) -> NSUInteger { - self.rows - } - - pub fn columns(&self) -> NSUInteger { - self.columns - } -} - -impl From> for Matrix { - fn from(buffer: MatrixBuffer) -> Self { - Self::new(buffer.contents(), buffer.rows, buffer.columns) - } -} - -impl Display for Matrix { +impl Display for MatrixBuffer { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - assert_eq!( - self.entries.len(), - self.rows as usize * self.columns as usize - ); + let contents = self.contents(); + assert_eq!(contents.len(), self.rows as usize * self.columns as usize); let mut col = 0; for i in 0..(self.rows * self.columns) as usize { if col == 0 { write!(f, "|")?; } - write!(f, "{:?}", self.entries.get(i).ok_or(std::fmt::Error)?)?; + write!(f, "{:?}", contents.get(i).ok_or(std::fmt::Error)?)?; if col < self.columns as usize - 1 { write!(f, ", ")?; @@ -981,13 +922,13 @@ impl Display for Matrix { } } -impl MatrixObject { +impl Matrix { fn init_with_device_descriptor( device: &DeviceRef, descriptor: &MatrixDescriptorRef, ) -> Option { unsafe { - let matrix: MatrixObject = msg_send![class!(MPSMatrix), alloc]; + let matrix: Matrix = msg_send![class!(MPSMatrix), alloc]; let ptr: *mut Object = msg_send![ matrix.as_ref(), initWithDevice : device @@ -1006,7 +947,7 @@ impl MatrixObject { descriptor: &MatrixDescriptorRef, ) -> Option { unsafe { - let matrix: MatrixObject = msg_send![class!(MPSMatrix), alloc]; + let matrix: Matrix = msg_send![class!(MPSMatrix), alloc]; let ptr: *mut Object = msg_send![ matrix.as_ref(), initWithBuffer : buffer @@ -1021,7 +962,7 @@ impl MatrixObject { } } -impl MatrixObjectRef { +impl MatrixRef { pub fn device(&self) -> &DeviceRef { unsafe { msg_send![self, device] } } @@ -1150,9 +1091,9 @@ impl MatrixMultiplicationRef { pub fn encode_to_command_buffer( &self, command_buffer: &CommandBufferRef, - left_matrix: &MatrixObjectRef, - right_matrix: &MatrixObjectRef, - result_matrix: &MatrixObjectRef, + left_matrix: &MatrixRef, + right_matrix: &MatrixRef, + result_matrix: &MatrixRef, ) { unsafe { let _: () = msg_send!( @@ -1194,6 +1135,17 @@ impl MatrixBuffer { } } + pub fn from_buffer(buffer: Buffer, rows: NSUInteger, columns: NSUInteger) -> Self { + MatrixBuffer { + buffer: buffer.clone(), + rows, + columns, + count: (rows * columns) as usize, + allocated_size: buffer.length() as usize, + _marker: PhantomData, + } + } + pub fn count(&self) -> usize { self.count } @@ -1208,12 +1160,12 @@ pub fn encode_gemm( command_buffer: &CommandBufferRef, transpose_left: bool, transpose_right: bool, - left: &Matrix, - right: &Matrix, + a: &MatrixBuffer, + b: &MatrixBuffer, + c: &MatrixBuffer, alpha: f64, beta: f64, -) -> MatrixBuffer -where +) where A: MPSDataType, B: MPSDataType, C: MPSDataType, @@ -1223,45 +1175,27 @@ where GEMMSpecification: Valid, { let (M, K) = if transpose_left { - (left.columns, left.rows) + (a.columns, a.rows) } else { - (left.rows, left.columns) + (a.rows, a.columns) }; let (N, B_K) = if transpose_right { - (right.rows, right.columns) + (b.rows, b.columns) } else { - (right.columns, right.rows) + (b.columns, b.rows) }; validate_shapes(M, N, K, B_K); - // Create descriptors for the matrices. - let left_row_bytes = MatrixDescriptor::row_bytes_for_columns(K, A::TYPE_ID); - let right_row_bytes = MatrixDescriptor::row_bytes_for_columns(N, B::TYPE_ID); - let result_row_bytes = MatrixDescriptor::row_bytes_for_columns(N, C::TYPE_ID); - - // Create buffers - let options = MTLResourceOptions::StorageModeShared; - let left_buffer = - device.new_buffer_with_data(left.entries.as_ptr().cast(), M * left_row_bytes, options); - let right_buffer = - device.new_buffer_with_data(right.entries.as_ptr().cast(), K * right_row_bytes, options); - - let result_buffer = MatrixBuffer::new(device, M, N, M * result_row_bytes, options); - // Create descriptors let left_descriptor = MatrixDescriptor::init_single(M, K, K * A::SIZE, A::TYPE_ID); let right_descriptor = MatrixDescriptor::init_single(K, N, N * B::SIZE, B::TYPE_ID); let result_descriptor = MatrixDescriptor::init_single(M, N, N * C::SIZE, C::TYPE_ID); // Create matrix objects - let left_matrix = - MatrixObject::init_with_buffer_descriptor(&left_buffer, &left_descriptor).unwrap(); - let right_matrix = - MatrixObject::init_with_buffer_descriptor(&right_buffer, &right_descriptor).unwrap(); - let result_matrix = - MatrixObject::init_with_buffer_descriptor(&result_buffer.buffer, &result_descriptor) - .unwrap(); + let left_matrix = Matrix::init_with_buffer_descriptor(&a.buffer, &left_descriptor).unwrap(); + let right_matrix = Matrix::init_with_buffer_descriptor(&b.buffer, &right_descriptor).unwrap(); + let result_matrix = Matrix::init_with_buffer_descriptor(&c.buffer, &result_descriptor).unwrap(); // Create kernel let matrix_multiplication = MatrixMultiplication::init( @@ -1283,9 +1217,6 @@ where &right_matrix, &result_matrix, ); - - // Return result buffer - result_buffer } fn validate_shapes(M: NSUInteger, N: NSUInteger, K: NSUInteger, B_K: NSUInteger) { From 7910f103d351e347e0a767b9e181388da6861e92 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 31 Oct 2023 19:32:03 +0100 Subject: [PATCH 10/12] Mark encode_gemm c parameter as &mut --- examples/mps/matrix-multiplication/main.rs | 18 ++++++++++++++---- src/mps.rs | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs index d362010..95a6e0f 100644 --- a/examples/mps/matrix-multiplication/main.rs +++ b/examples/mps/matrix-multiplication/main.rs @@ -40,10 +40,20 @@ fn correctness() { let a = generate_matrix::(&device); let b = generate_matrix::(&device); - let c = generate_matrix::(&device); + let mut c = generate_matrix::(&device); let command_buffer = command_queue.new_command_buffer(); - encode_gemm(&device, command_buffer, false, false, &a, &b, &c, 1.0, 0.0); + encode_gemm( + &device, + command_buffer, + false, + false, + &a, + &b, + &mut c, + 1.0, + 0.0, + ); command_buffer.commit(); command_buffer.wait_until_completed(); @@ -89,7 +99,7 @@ fn performance() { // Generate random matrices let a = generate_matrix::(&device); let b = generate_matrix::(&device); - let c = generate_matrix::(&device); + let mut c = generate_matrix::(&device); let cases = [ (false, false, 1.0, 0.0), @@ -114,7 +124,7 @@ fn performance() { t_right, &a, &b, - &c, + &mut c, alpha, beta, ); diff --git a/src/mps.rs b/src/mps.rs index 0bd6fb0..2bec3bd 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -1162,7 +1162,7 @@ pub fn encode_gemm( transpose_right: bool, a: &MatrixBuffer, b: &MatrixBuffer, - c: &MatrixBuffer, + c: &mut MatrixBuffer, alpha: f64, beta: f64, ) where From aaf264762b859578bde18c7a8f9c1eba48834d60 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 31 Oct 2023 19:42:06 +0100 Subject: [PATCH 11/12] Return Result from encode_gemm --- examples/mps/matrix-multiplication/main.rs | 6 ++++-- src/mps.rs | 16 +++++++++++----- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs index 95a6e0f..1268ed1 100644 --- a/examples/mps/matrix-multiplication/main.rs +++ b/examples/mps/matrix-multiplication/main.rs @@ -53,7 +53,8 @@ fn correctness() { &mut c, 1.0, 0.0, - ); + ) + .expect("Encoding failed"); command_buffer.commit(); command_buffer.wait_until_completed(); @@ -127,7 +128,8 @@ fn performance() { &mut c, alpha, beta, - ); + ) + .expect("Encoding failed"); } command_buffer.commit(); command_buffer.wait_until_completed(); diff --git a/src/mps.rs b/src/mps.rs index 2bec3bd..f200d8d 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -1165,7 +1165,8 @@ pub fn encode_gemm( c: &mut MatrixBuffer, alpha: f64, beta: f64, -) where +) -> Result<(), String> +where A: MPSDataType, B: MPSDataType, C: MPSDataType, @@ -1193,9 +1194,12 @@ pub fn encode_gemm( let result_descriptor = MatrixDescriptor::init_single(M, N, N * C::SIZE, C::TYPE_ID); // Create matrix objects - let left_matrix = Matrix::init_with_buffer_descriptor(&a.buffer, &left_descriptor).unwrap(); - let right_matrix = Matrix::init_with_buffer_descriptor(&b.buffer, &right_descriptor).unwrap(); - let result_matrix = Matrix::init_with_buffer_descriptor(&c.buffer, &result_descriptor).unwrap(); + let left_matrix = Matrix::init_with_buffer_descriptor(&a.buffer, &left_descriptor) + .ok_or_else(|| "Failed to create left matrix")?; + let right_matrix = Matrix::init_with_buffer_descriptor(&b.buffer, &right_descriptor) + .ok_or_else(|| "Failed to create right matrix")?; + let result_matrix = Matrix::init_with_buffer_descriptor(&c.buffer, &result_descriptor) + .ok_or_else(|| "Failed to create result matrix")?; // Create kernel let matrix_multiplication = MatrixMultiplication::init( @@ -1208,7 +1212,7 @@ pub fn encode_gemm( alpha, beta, ) - .unwrap(); + .ok_or_else(|| "Failed to create matrix multiplication kernel")?; // Encode kernel to command buffer matrix_multiplication.encode_to_command_buffer( @@ -1217,6 +1221,8 @@ pub fn encode_gemm( &right_matrix, &result_matrix, ); + + Ok(()) } fn validate_shapes(M: NSUInteger, N: NSUInteger, K: NSUInteger, B_K: NSUInteger) { From ad607682450455ad0606bcaae28fdafdaf9642a6 Mon Sep 17 00:00:00 2001 From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com> Date: Tue, 31 Oct 2023 19:44:08 +0100 Subject: [PATCH 12/12] MPSDataType TYPE_ID -> u32 --- src/mps.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/mps.rs b/src/mps.rs index f200d8d..5d57833 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -556,7 +556,7 @@ pub struct MPSIntersectionDistancePrimitiveIndexCoordinates { /// See . pub trait MPSDataType: Clone + Copy + PartialEq + Eq + Debug + Hash { type Type: Default + Clone + Copy + PartialEq + Debug + Sized; - const TYPE_ID: NSUInteger; + const TYPE_ID: u32; /// See . const SIZE: NSUInteger = ((Self::TYPE_ID & 0xFFFF) >> 3) as NSUInteger; @@ -567,24 +567,24 @@ pub trait MPSDataType: Clone + Copy + PartialEq + Eq + Debug + Hash { } /// A common bit for all floating point data types. Zero for integer types -const MPS_FLOATBIT_ENCODING: NSUInteger = 0x10000000; +const MPS_FLOATBIT_ENCODING: u32 = 0x10000000; /// A common bit for all complex point data types. Zero for integer types -const MPS_COMPLEXBIT_ENCODING: NSUInteger = MPS_FLOATBIT_ENCODING | 0x01000000; +const MPS_COMPLEXBIT_ENCODING: u32 = MPS_FLOATBIT_ENCODING | 0x01000000; /// A common bit for all signed data types -const MPS_SIGNEDBIT_ENCODING: NSUInteger = 0x20000000; +const MPS_SIGNEDBIT_ENCODING: u32 = 0x20000000; /// A common bit for all alternate encoding data types -const MPS_ALTERNATE_ENCODING: NSUInteger = 0x80000000; +const MPS_ALTERNATE_ENCODING: u32 = 0x80000000; /// A common bit for all normalized data types. /// If set, the value of the shall be interpreted as value / UNORM_TYPE_MAX /// Normalized values have range [0, 1.0] if unsigned and [-1,1] if signed. /// SNORM_TYPE_MIN is interpreted as SNORM_TYPE_MIN+1 per standard Metal rules. -const MPS_NORMALIZEDBIT_ENCODING: NSUInteger = 0x40000000; +const MPS_NORMALIZEDBIT_ENCODING: u32 = 0x40000000; macro_rules! mps_datatype_impl { ($dt:ident, $dt_ty:ty, $type_id:expr, $from_f64:expr, $to_f64:expr) => { impl MPSDataType for $dt { type Type = $dt_ty; - const TYPE_ID: NSUInteger = $type_id; + const TYPE_ID: u32 = $type_id; fn from_f64(v: f64) -> Self::Type { $from_f64(v) @@ -844,7 +844,7 @@ impl MatrixDescriptor { rows: NSUInteger, columns: NSUInteger, row_bytes: NSUInteger, - data_type: NSUInteger, + data_type: u32, ) -> Self { unsafe { msg_send![ @@ -878,7 +878,7 @@ impl MatrixDescriptor { } } - pub fn row_bytes_for_columns(columns: NSUInteger, data_type: NSUInteger) -> NSUInteger { + pub fn row_bytes_for_columns(columns: NSUInteger, data_type: u32) -> NSUInteger { unsafe { msg_send![ class!(MPSMatrixDescriptor),