From b926bbda039aeb644608399a2b56cfb0a7de6127 Mon Sep 17 00:00:00 2001 From: Boyd Johnson Date: Tue, 31 Dec 2024 20:00:22 +0000 Subject: [PATCH] Memory takes an owned AlignedBuffer --- benches/binary.rs | 16 ++++----- src/memory.rs | 81 +++++++++++++++++++++++++++++++++----------- src/memory/buffer.rs | 33 ------------------ src/primitive.rs | 6 ++-- tests/test_smoke.rs | 35 +++++++++---------- 5 files changed, 88 insertions(+), 83 deletions(-) diff --git a/benches/binary.rs b/benches/binary.rs index c196812..ffc0bcb 100644 --- a/benches/binary.rs +++ b/benches/binary.rs @@ -49,22 +49,20 @@ fn binary_add(b: &mut Bencher) { assert!(primitive.is_ok()); let primitive = primitive.unwrap(); - let mut s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into(); + let s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into(); // Allocate and initialize memory - let src0_memory = - Memory::new_with_user_buffer(engine.clone(), src0_desc, &mut s0_buffer).unwrap(); + let src0_memory = Memory::new_with_user_buffer(engine.clone(), src0_desc, s0_buffer).unwrap(); - let mut s1_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0]).unwrap().into(); + let s1_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0]).unwrap().into(); - let src1_memory = - Memory::new_with_user_buffer(engine.clone(), src1_desc, &mut s1_buffer).unwrap(); + let src1_memory = Memory::new_with_user_buffer(engine.clone(), src1_desc, s1_buffer).unwrap(); - let mut output = AlignedBuffer::::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32)) + let output = AlignedBuffer::::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32)) .unwrap() .into(); - let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, &mut output).unwrap(); + let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, output).unwrap(); b.iter(|| { // Create the primitive @@ -93,6 +91,6 @@ fn binary_add(b: &mut Bencher) { assert_eq!(result, Ok(())); - assert_eq!(output.to_vec::(), vec![5.0, 7.0, 9.0]); + assert_eq!(dst_memory.to_vec(), Ok(vec![5.0, 7.0, 9.0])); }); } diff --git a/src/memory.rs b/src/memory.rs index e504c75..df31c23 100644 --- a/src/memory.rs +++ b/src/memory.rs @@ -1,12 +1,14 @@ #![allow(non_snake_case)] use { crate::{engine::Engine, error::DnnlError}, - buffer::Buffer, + buffer::AlignedBuffer, descriptor::MemoryDescriptor, onednnl_sys::{ - dnnl_data_type_size, dnnl_data_type_t, dnnl_engine_kind_t, dnnl_memory, dnnl_memory_create, - dnnl_memory_destroy, dnnl_memory_t, dnnl_status_t, DNNL_GPU_RUNTIME, DNNL_RUNTIME_OCL, - DNNL_RUNTIME_SYCL, + dnnl_data_type_size, + dnnl_data_type_t::{self, dnnl_f32}, + dnnl_engine_kind_t, dnnl_memory, dnnl_memory_create, dnnl_memory_destroy, + dnnl_memory_get_data_handle, dnnl_memory_t, dnnl_status_t, DNNL_GPU_RUNTIME, + DNNL_RUNTIME_OCL, DNNL_RUNTIME_SYCL, }, std::{ffi::c_void, sync::Arc}, }; @@ -28,21 +30,21 @@ pub mod descriptor; pub mod format_tag; #[derive(Debug)] -pub enum BufferType { - UserAllocated(Buffer), +pub enum BufferType { + UserAllocated(AlignedBuffer), LibraryAllocated, None, } #[derive(Debug)] -pub struct Memory { +pub struct Memory { pub(crate) handle: dnnl_memory_t, pub engine: Arc, - pub buffer_type: BufferType, + pub buffer_type: BufferType, pub desc: MemoryDescriptor, } -impl Memory { +impl Memory { /// Creates a new memory object with a user-allocated buffer. /// /// This function initializes a `Memory` instance using a buffer provided by the user. @@ -83,14 +85,14 @@ impl Memory { /// /// /// let mem_desc = MemoryDescriptor::new::<6, abcdef>(dims, dnnl_f32).unwrap(); - /// let mut buffer = AlignedBuffer::::zeroed(mem_desc.get_size() / data_type_size(dnnl_f32)).unwrap().into(); - /// let memory = Memory::new_with_user_buffer(Arc::clone(&engine), mem_desc, &mut buffer); + /// let mut buffer = AlignedBuffer::::zeroed(mem_desc.get_size() / data_type_size(dnnl_f32)).unwrap(); + /// let memory = Memory::new_with_user_buffer(Arc::clone(&engine), mem_desc, buffer); /// assert!(memory.is_ok()); /// ``` pub fn new_with_user_buffer( engine: Arc, desc: MemoryDescriptor, - buffer: &mut Buffer, + buffer: AlignedBuffer, ) -> Result { let mut handle = std::ptr::null_mut::(); @@ -100,7 +102,7 @@ impl Memory { &mut handle, desc.handle, engine.handle, - buffer.as_ptr() as *mut c_void, + buffer.ptr.as_ptr() as *mut c_void, ) }, Ok(dnnl_engine_kind_t::dnnl_gpu) => { @@ -125,7 +127,7 @@ impl Memory { Ok(Memory { handle, engine, - buffer_type: BufferType::UserAllocated(buffer.clone()), + buffer_type: BufferType::UserAllocated(buffer), desc, }) } else { @@ -152,7 +154,7 @@ impl Memory { /// let mem_desc = MemoryDescriptor::new::<4, abcd>(dims, dnnl_f32).unwrap(); /// /// - /// let memory = Memory::new_with_library_buffer(Arc::clone(&engine), mem_desc); + /// let memory = Memory::::new_with_library_buffer(Arc::clone(&engine), mem_desc); /// assert!(memory.is_ok()); /// ``` pub fn new_with_library_buffer( @@ -200,7 +202,7 @@ impl Memory { /// let mem_desc = MemoryDescriptor::new::<6, abcdef>(dims, dnnl_f32).unwrap(); /// /// - /// let memory = Memory::new_without_buffer(Arc::clone(&engine), mem_desc); + /// let memory = Memory::::new_without_buffer(Arc::clone(&engine), mem_desc); /// assert!(memory.is_ok()); /// ``` pub fn new_without_buffer( @@ -223,13 +225,54 @@ impl Memory { Err(status.into()) } } + + pub fn to_vec(&self) -> Result, DnnlError> + where + T: Clone, + { + match self.engine.get_kind() { + Ok(dnnl_engine_kind_t::dnnl_cpu) => match &self.buffer_type { + BufferType::UserAllocated(buffer) => Ok(buffer.as_slice().to_vec()), + BufferType::LibraryAllocated => { + let buffer = AlignedBuffer::::zeroed( + self.desc.get_size() / unsafe { dnnl_data_type_size(dnnl_f32) }, + ) + .unwrap(); + + let status = unsafe { + dnnl_memory_get_data_handle( + self.handle, + buffer.ptr.as_ptr() as *mut *mut c_void, + ) + }; + + if status == dnnl_status_t::dnnl_success { + Ok(buffer.as_slice().to_vec()) + } else { + Err(status.into()) + } + } + BufferType::None => todo!("return error"), + }, + Ok(dnnl_engine_kind_t::dnnl_gpu) => { + todo!("Return the right data") + } + Ok(dnnl_engine_kind_t::dnnl_any_engine) => { + todo!("Return the right data") + } + Ok(t) => { + panic!("Received incorrect engine_kind_t: {}", t) + } + Err(e) => Err(e), + } + } } -impl Drop for Memory { +impl Drop for Memory { fn drop(&mut self) { unsafe { dnnl_memory_destroy(self.handle) }; } } -unsafe impl Sync for Memory {} -unsafe impl Send for Memory {} +unsafe impl Sync for Memory {} +unsafe impl Send for Memory {} diff --git a/src/memory/buffer.rs b/src/memory/buffer.rs index 63c0df5..797fd3e 100644 --- a/src/memory/buffer.rs +++ b/src/memory/buffer.rs @@ -3,7 +3,6 @@ use { std::{ alloc::{alloc, alloc_zeroed, dealloc, Layout}, ptr::NonNull, - sync::{Arc, RwLock}, }, }; @@ -66,35 +65,3 @@ impl Drop for AlignedBuffer { } } } - -#[derive(Debug, Clone)] -pub enum Buffer { - FloatBuffer(Arc>>), -} - -impl From> for Buffer { - fn from(value: AlignedBuffer) -> Self { - Self::FloatBuffer(Arc::new(RwLock::new(value))) - } -} - -impl Buffer { - pub(crate) fn as_ptr(&mut self) -> *mut T { - match self { - Self::FloatBuffer(buffer) => buffer.write().unwrap().ptr.as_ptr() as *mut T, - } - } - - pub fn to_vec(&self) -> Vec - where - T: From, - { - match self { - Self::FloatBuffer(buffer) => { - let guard = buffer.read().unwrap(); - - guard.as_slice().iter().copied().map(|t| t.into()).collect() - } - } - } -} diff --git a/src/primitive.rs b/src/primitive.rs index 4b0320a..a74fff2 100644 --- a/src/primitive.rs +++ b/src/primitive.rs @@ -231,7 +231,7 @@ impl Primitive { } } - pub fn execute(&self, stream: &Stream, args: Vec>) -> Result<(), DnnlError> { + pub fn execute(&self, stream: &Stream, args: Vec>) -> Result<(), DnnlError> { let c_args: Vec = args .iter() .map(|arg| dnnl_exec_arg_t { @@ -265,7 +265,7 @@ impl Drop for Primitive { } } -pub struct ExecArg<'a> { +pub struct ExecArg<'a, T> { pub index: i32, - pub mem: &'a Memory, + pub mem: &'a Memory, } diff --git a/tests/test_smoke.rs b/tests/test_smoke.rs index c8b03da..8927803 100644 --- a/tests/test_smoke.rs +++ b/tests/test_smoke.rs @@ -46,22 +46,20 @@ pub fn test_smoke_binary_add() { assert!(primitive.is_ok()); let primitive = primitive.unwrap(); - let mut s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into(); + let s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into(); // Allocate and initialize memory - let src0_memory = - Memory::new_with_user_buffer(engine.clone(), src0_desc, &mut s0_buffer).unwrap(); + let src0_memory = Memory::new_with_user_buffer(engine.clone(), src0_desc, s0_buffer).unwrap(); - let mut s1_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0]).unwrap().into(); + let s1_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0]).unwrap().into(); - let src1_memory = - Memory::new_with_user_buffer(engine.clone(), src1_desc, &mut s1_buffer).unwrap(); + let src1_memory = Memory::new_with_user_buffer(engine.clone(), src1_desc, s1_buffer).unwrap(); - let mut output = AlignedBuffer::::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32)) + let output = AlignedBuffer::::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32)) .unwrap() .into(); - let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, &mut output).unwrap(); + let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, output).unwrap(); // Configure the binary operation @@ -88,7 +86,7 @@ pub fn test_smoke_binary_add() { assert_eq!(result, Ok(())); - assert_eq!(output.to_vec::(), vec![5.0, 7.0, 9.0]); + assert_eq!(dst_memory.to_vec(), Ok(vec![5.0, 7.0, 9.0])); } #[test] @@ -115,13 +113,13 @@ pub fn test_smoke_matmul() { // Step 3: Allocate Aligned Buffers // Initialize src and weights with sample data, dst with zeros - let mut src_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]) + let src_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]) .expect("Failed to allocate src buffer") .into(); - let mut weights_buffer = AlignedBuffer::new(&[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0]) + let weights_buffer = AlignedBuffer::new(&[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0]) .expect("Failed to allocate weights buffer") .into(); - let mut output_buffer = + let output_buffer = AlignedBuffer::::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32)) .expect("Failed to allocate output buffer") .into(); @@ -146,17 +144,16 @@ pub fn test_smoke_matmul() { // Step 6: Create Memory Objects // Wrap the buffers into oneDNN Memory objects - let src_memory = Memory::new_with_user_buffer(engine.clone(), src_desc, &mut src_buffer) + let src_memory = Memory::new_with_user_buffer(engine.clone(), src_desc, src_buffer) .expect("Failed to create src memory"); - let weights_memory = - Memory::new_with_user_buffer(engine.clone(), weights_desc, &mut weights_buffer) - .expect("Failed to create weights memory"); + let weights_memory = Memory::new_with_user_buffer(engine.clone(), weights_desc, weights_buffer) + .expect("Failed to create weights memory"); // Since we are disabling bias, create a Memory object without a buffer let bias_memory = Memory::new_without_buffer(engine.clone(), zero_bias_desc) .expect("Failed to create bias memory (disabled)"); - let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, &mut output_buffer) + let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, output_buffer) .expect("Failed to create destination memory"); // Step 7: Create a Stream @@ -205,8 +202,8 @@ pub fn test_smoke_matmul() { let expected = vec![58.0f32, 64.0, 139.0, 154.0]; assert_eq!( - output_buffer.to_vec::(), - expected, + dst_memory.to_vec(), + Ok(expected), "MatMul output does not match expected results" ); }