Skip to content

Commit

Permalink
Memory takes an owned AlignedBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
boydjohnson committed Dec 31, 2024
1 parent 30aff54 commit b926bbd
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 83 deletions.
16 changes: 7 additions & 9 deletions benches/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,20 @@ fn binary_add(b: &mut Bencher) {
assert!(primitive.is_ok());
let primitive = primitive.unwrap();

let mut s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into();
let s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into();

// Allocate and initialize memory
let src0_memory =
Memory::new_with_user_buffer(engine.clone(), src0_desc, &mut s0_buffer).unwrap();
let src0_memory = Memory::new_with_user_buffer(engine.clone(), src0_desc, s0_buffer).unwrap();

let mut s1_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0]).unwrap().into();
let s1_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0]).unwrap().into();

let src1_memory =
Memory::new_with_user_buffer(engine.clone(), src1_desc, &mut s1_buffer).unwrap();
let src1_memory = Memory::new_with_user_buffer(engine.clone(), src1_desc, s1_buffer).unwrap();

let mut output = AlignedBuffer::<f32>::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32))
let output = AlignedBuffer::<f32>::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32))
.unwrap()
.into();

let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, &mut output).unwrap();
let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, output).unwrap();

b.iter(|| {
// Create the primitive
Expand Down Expand Up @@ -93,6 +91,6 @@ fn binary_add(b: &mut Bencher) {

assert_eq!(result, Ok(()));

assert_eq!(output.to_vec::<f32>(), vec![5.0, 7.0, 9.0]);
assert_eq!(dst_memory.to_vec(), Ok(vec![5.0, 7.0, 9.0]));
});
}
81 changes: 62 additions & 19 deletions src/memory.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
#![allow(non_snake_case)]
use {
crate::{engine::Engine, error::DnnlError},
buffer::Buffer,
buffer::AlignedBuffer,
descriptor::MemoryDescriptor,
onednnl_sys::{
dnnl_data_type_size, dnnl_data_type_t, dnnl_engine_kind_t, dnnl_memory, dnnl_memory_create,
dnnl_memory_destroy, dnnl_memory_t, dnnl_status_t, DNNL_GPU_RUNTIME, DNNL_RUNTIME_OCL,
DNNL_RUNTIME_SYCL,
dnnl_data_type_size,
dnnl_data_type_t::{self, dnnl_f32},
dnnl_engine_kind_t, dnnl_memory, dnnl_memory_create, dnnl_memory_destroy,
dnnl_memory_get_data_handle, dnnl_memory_t, dnnl_status_t, DNNL_GPU_RUNTIME,
DNNL_RUNTIME_OCL, DNNL_RUNTIME_SYCL,
},
std::{ffi::c_void, sync::Arc},
};
Expand All @@ -28,21 +30,21 @@ pub mod descriptor;
pub mod format_tag;

#[derive(Debug)]
pub enum BufferType {
UserAllocated(Buffer),
pub enum BufferType<T> {
UserAllocated(AlignedBuffer<T>),
LibraryAllocated,
None,
}

#[derive(Debug)]
pub struct Memory {
pub struct Memory<T> {
pub(crate) handle: dnnl_memory_t,
pub engine: Arc<Engine>,
pub buffer_type: BufferType,
pub buffer_type: BufferType<T>,
pub desc: MemoryDescriptor,
}

impl Memory {
impl<T> Memory<T> {
/// Creates a new memory object with a user-allocated buffer.
///
/// This function initializes a `Memory` instance using a buffer provided by the user.
Expand Down Expand Up @@ -83,14 +85,14 @@ impl Memory {
///
///
/// let mem_desc = MemoryDescriptor::new::<6, abcdef>(dims, dnnl_f32).unwrap();
/// let mut buffer = AlignedBuffer::<f32>::zeroed(mem_desc.get_size() / data_type_size(dnnl_f32)).unwrap().into();
/// let memory = Memory::new_with_user_buffer(Arc::clone(&engine), mem_desc, &mut buffer);
/// let mut buffer = AlignedBuffer::<f32>::zeroed(mem_desc.get_size() / data_type_size(dnnl_f32)).unwrap();
/// let memory = Memory::new_with_user_buffer(Arc::clone(&engine), mem_desc, buffer);
/// assert!(memory.is_ok());
/// ```
pub fn new_with_user_buffer(
engine: Arc<Engine>,
desc: MemoryDescriptor,
buffer: &mut Buffer,
buffer: AlignedBuffer<T>,
) -> Result<Self, DnnlError> {
let mut handle = std::ptr::null_mut::<dnnl_memory>();

Expand All @@ -100,7 +102,7 @@ impl Memory {
&mut handle,
desc.handle,
engine.handle,
buffer.as_ptr() as *mut c_void,
buffer.ptr.as_ptr() as *mut c_void,
)
},
Ok(dnnl_engine_kind_t::dnnl_gpu) => {
Expand All @@ -125,7 +127,7 @@ impl Memory {
Ok(Memory {
handle,
engine,
buffer_type: BufferType::UserAllocated(buffer.clone()),
buffer_type: BufferType::UserAllocated(buffer),
desc,
})
} else {
Expand All @@ -152,7 +154,7 @@ impl Memory {
/// let mem_desc = MemoryDescriptor::new::<4, abcd>(dims, dnnl_f32).unwrap();
///
///
/// let memory = Memory::new_with_library_buffer(Arc::clone(&engine), mem_desc);
/// let memory = Memory::<f32>::new_with_library_buffer(Arc::clone(&engine), mem_desc);
/// assert!(memory.is_ok());
/// ```
pub fn new_with_library_buffer(
Expand Down Expand Up @@ -200,7 +202,7 @@ impl Memory {
/// let mem_desc = MemoryDescriptor::new::<6, abcdef>(dims, dnnl_f32).unwrap();
///
///
/// let memory = Memory::new_without_buffer(Arc::clone(&engine), mem_desc);
/// let memory = Memory::<f32>::new_without_buffer(Arc::clone(&engine), mem_desc);
/// assert!(memory.is_ok());
/// ```
pub fn new_without_buffer(
Expand All @@ -223,13 +225,54 @@ impl Memory {
Err(status.into())
}
}

pub fn to_vec(&self) -> Result<Vec<T>, DnnlError>
where
T: Clone,
{
match self.engine.get_kind() {
Ok(dnnl_engine_kind_t::dnnl_cpu) => match &self.buffer_type {
BufferType::UserAllocated(buffer) => Ok(buffer.as_slice().to_vec()),
BufferType::LibraryAllocated => {
let buffer = AlignedBuffer::<T>::zeroed(
self.desc.get_size() / unsafe { dnnl_data_type_size(dnnl_f32) },
)
.unwrap();

let status = unsafe {
dnnl_memory_get_data_handle(
self.handle,
buffer.ptr.as_ptr() as *mut *mut c_void,
)
};

if status == dnnl_status_t::dnnl_success {
Ok(buffer.as_slice().to_vec())
} else {
Err(status.into())
}
}
BufferType::None => todo!("return error"),
},
Ok(dnnl_engine_kind_t::dnnl_gpu) => {
todo!("Return the right data")
}
Ok(dnnl_engine_kind_t::dnnl_any_engine) => {
todo!("Return the right data")
}
Ok(t) => {
panic!("Received incorrect engine_kind_t: {}", t)
}
Err(e) => Err(e),
}
}
}

impl Drop for Memory {
impl<T> Drop for Memory<T> {
fn drop(&mut self) {
unsafe { dnnl_memory_destroy(self.handle) };
}
}

unsafe impl Sync for Memory {}
unsafe impl Send for Memory {}
unsafe impl<T> Sync for Memory<T> {}
unsafe impl<T> Send for Memory<T> {}
33 changes: 0 additions & 33 deletions src/memory/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use {
std::{
alloc::{alloc, alloc_zeroed, dealloc, Layout},
ptr::NonNull,
sync::{Arc, RwLock},
},
};

Expand Down Expand Up @@ -66,35 +65,3 @@ impl<T> Drop for AlignedBuffer<T> {
}
}
}

#[derive(Debug, Clone)]
pub enum Buffer {
FloatBuffer(Arc<RwLock<AlignedBuffer<f32>>>),
}

impl From<AlignedBuffer<f32>> for Buffer {
fn from(value: AlignedBuffer<f32>) -> Self {
Self::FloatBuffer(Arc::new(RwLock::new(value)))
}
}

impl Buffer {
pub(crate) fn as_ptr<T>(&mut self) -> *mut T {
match self {
Self::FloatBuffer(buffer) => buffer.write().unwrap().ptr.as_ptr() as *mut T,
}
}

pub fn to_vec<T>(&self) -> Vec<T>
where
T: From<f32>,
{
match self {
Self::FloatBuffer(buffer) => {
let guard = buffer.read().unwrap();

guard.as_slice().iter().copied().map(|t| t.into()).collect()
}
}
}
}
6 changes: 3 additions & 3 deletions src/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ impl Primitive {
}
}

pub fn execute(&self, stream: &Stream, args: Vec<ExecArg<'_>>) -> Result<(), DnnlError> {
pub fn execute<T>(&self, stream: &Stream, args: Vec<ExecArg<'_, T>>) -> Result<(), DnnlError> {
let c_args: Vec<dnnl_exec_arg_t> = args
.iter()
.map(|arg| dnnl_exec_arg_t {
Expand Down Expand Up @@ -265,7 +265,7 @@ impl Drop for Primitive {
}
}

pub struct ExecArg<'a> {
pub struct ExecArg<'a, T> {
pub index: i32,
pub mem: &'a Memory,
pub mem: &'a Memory<T>,
}
35 changes: 16 additions & 19 deletions tests/test_smoke.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,22 +46,20 @@ pub fn test_smoke_binary_add() {
assert!(primitive.is_ok());
let primitive = primitive.unwrap();

let mut s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into();
let s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into();

// Allocate and initialize memory
let src0_memory =
Memory::new_with_user_buffer(engine.clone(), src0_desc, &mut s0_buffer).unwrap();
let src0_memory = Memory::new_with_user_buffer(engine.clone(), src0_desc, s0_buffer).unwrap();

let mut s1_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0]).unwrap().into();
let s1_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0]).unwrap().into();

let src1_memory =
Memory::new_with_user_buffer(engine.clone(), src1_desc, &mut s1_buffer).unwrap();
let src1_memory = Memory::new_with_user_buffer(engine.clone(), src1_desc, s1_buffer).unwrap();

let mut output = AlignedBuffer::<f32>::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32))
let output = AlignedBuffer::<f32>::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32))
.unwrap()
.into();

let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, &mut output).unwrap();
let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, output).unwrap();

// Configure the binary operation

Expand All @@ -88,7 +86,7 @@ pub fn test_smoke_binary_add() {

assert_eq!(result, Ok(()));

assert_eq!(output.to_vec::<f32>(), vec![5.0, 7.0, 9.0]);
assert_eq!(dst_memory.to_vec(), Ok(vec![5.0, 7.0, 9.0]));
}

#[test]
Expand All @@ -115,13 +113,13 @@ pub fn test_smoke_matmul() {

// Step 3: Allocate Aligned Buffers
// Initialize src and weights with sample data, dst with zeros
let mut src_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0])
let src_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0])
.expect("Failed to allocate src buffer")
.into();
let mut weights_buffer = AlignedBuffer::new(&[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0])
let weights_buffer = AlignedBuffer::new(&[7.0f32, 8.0, 9.0, 10.0, 11.0, 12.0])
.expect("Failed to allocate weights buffer")
.into();
let mut output_buffer =
let output_buffer =
AlignedBuffer::<f32>::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32))
.expect("Failed to allocate output buffer")
.into();
Expand All @@ -146,17 +144,16 @@ pub fn test_smoke_matmul() {

// Step 6: Create Memory Objects
// Wrap the buffers into oneDNN Memory objects
let src_memory = Memory::new_with_user_buffer(engine.clone(), src_desc, &mut src_buffer)
let src_memory = Memory::new_with_user_buffer(engine.clone(), src_desc, src_buffer)
.expect("Failed to create src memory");
let weights_memory =
Memory::new_with_user_buffer(engine.clone(), weights_desc, &mut weights_buffer)
.expect("Failed to create weights memory");
let weights_memory = Memory::new_with_user_buffer(engine.clone(), weights_desc, weights_buffer)
.expect("Failed to create weights memory");

// Since we are disabling bias, create a Memory object without a buffer
let bias_memory = Memory::new_without_buffer(engine.clone(), zero_bias_desc)
.expect("Failed to create bias memory (disabled)");

let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, &mut output_buffer)
let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, output_buffer)
.expect("Failed to create destination memory");

// Step 7: Create a Stream
Expand Down Expand Up @@ -205,8 +202,8 @@ pub fn test_smoke_matmul() {

let expected = vec![58.0f32, 64.0, 139.0, 154.0];
assert_eq!(
output_buffer.to_vec::<f32>(),
expected,
dst_memory.to_vec(),
Ok(expected),
"MatMul output does not match expected results"
);
}
Expand Down

0 comments on commit b926bbd

Please sign in to comment.