Skip to content

Commit

Permalink
Merge pull request #12 from boydjohnson/feature/skeleton-for-gpu
Browse files Browse the repository at this point in the history
Add Memory::to_vec
  • Loading branch information
boydjohnson authored Dec 31, 2024
2 parents bcc457c + b926bbd commit 069b63a
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 89 deletions.
19 changes: 10 additions & 9 deletions benches/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use {
config::binary::{Binary, ForwardBinaryConfig},
ExecArg, ForwardBinary, Primitive, PropForwardInference,
},
set_primitive_cache_capacity,
stream::Stream,
},
onednnl_sys::{dnnl_data_type_t::dnnl_f32, DNNL_ARG_DST, DNNL_ARG_SRC_0, DNNL_ARG_SRC_1},
Expand All @@ -27,6 +28,8 @@ use {
fn binary_add(b: &mut Bencher) {
let engine = Engine::new(Engine::CPU, 0).unwrap();

set_primitive_cache_capacity(2).unwrap();

let stream = Arc::new(Stream::new(engine.clone()).unwrap());

let src0_desc = MemoryDescriptor::new::<1, x>([3], DataType::F32).unwrap();
Expand All @@ -46,22 +49,20 @@ fn binary_add(b: &mut Bencher) {
assert!(primitive.is_ok());
let primitive = primitive.unwrap();

let mut s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into();
let s0_buffer = AlignedBuffer::new(&[4.0f32, 5.0, 6.0]).unwrap().into();

// Allocate and initialize memory
let src0_memory =
Memory::new_with_user_buffer(engine.clone(), src0_desc, &mut s0_buffer).unwrap();
let src0_memory = Memory::new_with_user_buffer(engine.clone(), src0_desc, s0_buffer).unwrap();

let mut s1_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0]).unwrap().into();
let s1_buffer = AlignedBuffer::new(&[1.0f32, 2.0, 3.0]).unwrap().into();

let src1_memory =
Memory::new_with_user_buffer(engine.clone(), src1_desc, &mut s1_buffer).unwrap();
let src1_memory = Memory::new_with_user_buffer(engine.clone(), src1_desc, s1_buffer).unwrap();

let mut output = AlignedBuffer::<f32>::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32))
let output = AlignedBuffer::<f32>::zeroed(dst_desc.get_size() / data_type_size(dnnl_f32))
.unwrap()
.into();

let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, &mut output).unwrap();
let dst_memory = Memory::new_with_user_buffer(engine.clone(), dst_desc, output).unwrap();

b.iter(|| {
// Create the primitive
Expand Down Expand Up @@ -90,6 +91,6 @@ fn binary_add(b: &mut Bencher) {

assert_eq!(result, Ok(()));

assert_eq!(output.to_vec::<f32>(), vec![5.0, 7.0, 9.0]);
assert_eq!(dst_memory.to_vec(), Ok(vec![5.0, 7.0, 9.0]));
});
}
2 changes: 1 addition & 1 deletion src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl Engine {
///
/// assert_eq!(engine.get_kind(), Ok(Engine::CPU));
/// ```
pub fn get_kind(self: Arc<Self>) -> Result<dnnl_engine_kind_t::Type, DnnlError> {
pub fn get_kind(self: &Arc<Self>) -> Result<dnnl_engine_kind_t::Type, DnnlError> {
let mut kind: dnnl_engine_kind_t::Type = 0; // Initialize a variable to store the kind
let status = unsafe { dnnl_engine_get_kind(self.handle, &mut kind) }; // Pass a mutable reference

Expand Down
11 changes: 11 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,15 @@ pub mod memory;
pub mod primitive;
pub mod stream;

use error::DnnlError;
pub use onednnl_sys;

pub fn set_primitive_cache_capacity(capacity: std::ffi::c_int) -> Result<(), DnnlError> {
let status = unsafe { onednnl_sys::dnnl_set_primitive_cache_capacity(capacity) };

if status == onednnl_sys::dnnl_status_t::dnnl_success {
Ok(())
} else {
Err(status.into())
}
}
113 changes: 89 additions & 24 deletions src/memory.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
#![allow(non_snake_case)]
use {
crate::{engine::Engine, error::DnnlError},
buffer::Buffer,
buffer::AlignedBuffer,
descriptor::MemoryDescriptor,
onednnl_sys::{
dnnl_data_type_size, dnnl_data_type_t, dnnl_memory, dnnl_memory_create,
dnnl_memory_destroy, dnnl_memory_t, dnnl_status_t,
dnnl_data_type_size,
dnnl_data_type_t::{self, dnnl_f32},
dnnl_engine_kind_t, dnnl_memory, dnnl_memory_create, dnnl_memory_destroy,
dnnl_memory_get_data_handle, dnnl_memory_t, dnnl_status_t, DNNL_GPU_RUNTIME,
DNNL_RUNTIME_OCL, DNNL_RUNTIME_SYCL,
},
std::{ffi::c_void, sync::Arc},
};

/// Get the size for a data type.
pub fn data_type_size(ty: dnnl_data_type_t::Type) -> usize {
unsafe { dnnl_data_type_size(ty) }
}
Expand All @@ -25,21 +30,21 @@ pub mod descriptor;
pub mod format_tag;

#[derive(Debug)]
pub enum BufferType {
UserAllocated(Buffer),
pub enum BufferType<T> {
UserAllocated(AlignedBuffer<T>),
LibraryAllocated,
None,
}

#[derive(Debug)]
pub struct Memory {
pub struct Memory<T> {
pub(crate) handle: dnnl_memory_t,
pub engine: Arc<Engine>,
pub buffer_type: BufferType,
pub buffer_type: BufferType<T>,
pub desc: MemoryDescriptor,
}

impl Memory {
impl<T> Memory<T> {
/// Creates a new memory object with a user-allocated buffer.
///
/// This function initializes a `Memory` instance using a buffer provided by the user.
Expand Down Expand Up @@ -80,30 +85,49 @@ impl Memory {
///
///
/// let mem_desc = MemoryDescriptor::new::<6, abcdef>(dims, dnnl_f32).unwrap();
/// let mut buffer = AlignedBuffer::<f32>::zeroed(mem_desc.get_size() / data_type_size(dnnl_f32)).unwrap().into();
/// let memory = Memory::new_with_user_buffer(Arc::clone(&engine), mem_desc, &mut buffer);
/// let mut buffer = AlignedBuffer::<f32>::zeroed(mem_desc.get_size() / data_type_size(dnnl_f32)).unwrap();
/// let memory = Memory::new_with_user_buffer(Arc::clone(&engine), mem_desc, buffer);
/// assert!(memory.is_ok());
/// ```
pub fn new_with_user_buffer(
engine: Arc<Engine>,
desc: MemoryDescriptor,
buffer: &mut Buffer,
buffer: AlignedBuffer<T>,
) -> Result<Self, DnnlError> {
let mut handle = std::ptr::null_mut::<dnnl_memory>();
let status = unsafe {
dnnl_memory_create(
&mut handle,
desc.handle,
engine.handle,
buffer.as_ptr() as *mut c_void,
)

let status = match engine.get_kind() {
Ok(dnnl_engine_kind_t::dnnl_cpu) => unsafe {
dnnl_memory_create(
&mut handle,
desc.handle,
engine.handle,
buffer.ptr.as_ptr() as *mut c_void,
)
},
Ok(dnnl_engine_kind_t::dnnl_gpu) => {
if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL {
todo!("Add SYCL interop")
} else if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL {
todo!("Add OCL interop")
} else {
todo!("Return Error for lack of a GPU Runtime")
}
}
Ok(dnnl_engine_kind_t::dnnl_any_engine) => {
todo!("Add DNNL ANY interop")
}
Ok(_) => {
panic!("Unexpected engine kind type type")
}
Err(e) => return Err(e),
};

if status == dnnl_status_t::dnnl_success {
Ok(Memory {
handle,
engine,
buffer_type: BufferType::UserAllocated(buffer.clone()),
buffer_type: BufferType::UserAllocated(buffer),
desc,
})
} else {
Expand All @@ -130,7 +154,7 @@ impl Memory {
/// let mem_desc = MemoryDescriptor::new::<4, abcd>(dims, dnnl_f32).unwrap();
///
///
/// let memory = Memory::new_with_library_buffer(Arc::clone(&engine), mem_desc);
/// let memory = Memory::<f32>::new_with_library_buffer(Arc::clone(&engine), mem_desc);
/// assert!(memory.is_ok());
/// ```
pub fn new_with_library_buffer(
Expand Down Expand Up @@ -178,7 +202,7 @@ impl Memory {
/// let mem_desc = MemoryDescriptor::new::<6, abcdef>(dims, dnnl_f32).unwrap();
///
///
/// let memory = Memory::new_without_buffer(Arc::clone(&engine), mem_desc);
/// let memory = Memory::<f32>::new_without_buffer(Arc::clone(&engine), mem_desc);
/// assert!(memory.is_ok());
/// ```
pub fn new_without_buffer(
Expand All @@ -201,13 +225,54 @@ impl Memory {
Err(status.into())
}
}

pub fn to_vec(&self) -> Result<Vec<T>, DnnlError>
where
T: Clone,
{
match self.engine.get_kind() {
Ok(dnnl_engine_kind_t::dnnl_cpu) => match &self.buffer_type {
BufferType::UserAllocated(buffer) => Ok(buffer.as_slice().to_vec()),
BufferType::LibraryAllocated => {
let buffer = AlignedBuffer::<T>::zeroed(
self.desc.get_size() / unsafe { dnnl_data_type_size(dnnl_f32) },
)
.unwrap();

let status = unsafe {
dnnl_memory_get_data_handle(
self.handle,
buffer.ptr.as_ptr() as *mut *mut c_void,
)
};

if status == dnnl_status_t::dnnl_success {
Ok(buffer.as_slice().to_vec())
} else {
Err(status.into())
}
}
BufferType::None => todo!("return error"),
},
Ok(dnnl_engine_kind_t::dnnl_gpu) => {
todo!("Return the right data")
}
Ok(dnnl_engine_kind_t::dnnl_any_engine) => {
todo!("Return the right data")
}
Ok(t) => {
panic!("Received incorrect engine_kind_t: {}", t)
}
Err(e) => Err(e),
}
}
}

impl Drop for Memory {
impl<T> Drop for Memory<T> {
fn drop(&mut self) {
unsafe { dnnl_memory_destroy(self.handle) };
}
}

unsafe impl Sync for Memory {}
unsafe impl Send for Memory {}
unsafe impl<T> Sync for Memory<T> {}
unsafe impl<T> Send for Memory<T> {}
33 changes: 0 additions & 33 deletions src/memory/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use {
std::{
alloc::{alloc, alloc_zeroed, dealloc, Layout},
ptr::NonNull,
sync::{Arc, RwLock},
},
};

Expand Down Expand Up @@ -66,35 +65,3 @@ impl<T> Drop for AlignedBuffer<T> {
}
}
}

#[derive(Debug, Clone)]
pub enum Buffer {
FloatBuffer(Arc<RwLock<AlignedBuffer<f32>>>),
}

impl From<AlignedBuffer<f32>> for Buffer {
fn from(value: AlignedBuffer<f32>) -> Self {
Self::FloatBuffer(Arc::new(RwLock::new(value)))
}
}

impl Buffer {
pub(crate) fn as_ptr<T>(&mut self) -> *mut T {
match self {
Self::FloatBuffer(buffer) => buffer.write().unwrap().ptr.as_ptr() as *mut T,
}
}

pub fn to_vec<T>(&self) -> Vec<T>
where
T: From<f32>,
{
match self {
Self::FloatBuffer(buffer) => {
let guard = buffer.read().unwrap();

guard.as_slice().iter().copied().map(|t| t.into()).collect()
}
}
}
}
6 changes: 3 additions & 3 deletions src/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ impl Primitive {
}
}

pub fn execute(&self, stream: &Stream, args: Vec<ExecArg<'_>>) -> Result<(), DnnlError> {
pub fn execute<T>(&self, stream: &Stream, args: Vec<ExecArg<'_, T>>) -> Result<(), DnnlError> {
let c_args: Vec<dnnl_exec_arg_t> = args
.iter()
.map(|arg| dnnl_exec_arg_t {
Expand Down Expand Up @@ -265,7 +265,7 @@ impl Drop for Primitive {
}
}

pub struct ExecArg<'a> {
pub struct ExecArg<'a, T> {
pub index: i32,
pub mem: &'a Memory,
pub mem: &'a Memory<T>,
}
Loading

0 comments on commit 069b63a

Please sign in to comment.