Skip to content

Commit

Permalink
Add GPU SYCL/OCL specific code
Browse files Browse the repository at this point in the history
  • Loading branch information
boydjohnson committed Dec 31, 2024
1 parent f78af1a commit 30aff54
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl Engine {
///
/// assert_eq!(engine.get_kind(), Ok(Engine::CPU));
/// ```
pub fn get_kind(self: Arc<Self>) -> Result<dnnl_engine_kind_t::Type, DnnlError> {
pub fn get_kind(self: &Arc<Self>) -> Result<dnnl_engine_kind_t::Type, DnnlError> {
let mut kind: dnnl_engine_kind_t::Type = 0; // Initialize a variable to store the kind
let status = unsafe { dnnl_engine_get_kind(self.handle, &mut kind) }; // Pass a mutable reference

Expand Down
40 changes: 31 additions & 9 deletions src/memory.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
#![allow(non_snake_case)]
use {
crate::{engine::Engine, error::DnnlError},
buffer::Buffer,
descriptor::MemoryDescriptor,
onednnl_sys::{
dnnl_data_type_size, dnnl_data_type_t, dnnl_memory, dnnl_memory_create,
dnnl_memory_destroy, dnnl_memory_t, dnnl_status_t,
dnnl_data_type_size, dnnl_data_type_t, dnnl_engine_kind_t, dnnl_memory, dnnl_memory_create,
dnnl_memory_destroy, dnnl_memory_t, dnnl_status_t, DNNL_GPU_RUNTIME, DNNL_RUNTIME_OCL,
DNNL_RUNTIME_SYCL,
},
std::{ffi::c_void, sync::Arc},
};

/// Get the size for a data type.
pub fn data_type_size(ty: dnnl_data_type_t::Type) -> usize {
unsafe { dnnl_data_type_size(ty) }
}
Expand Down Expand Up @@ -90,13 +93,32 @@ impl Memory {
buffer: &mut Buffer,
) -> Result<Self, DnnlError> {
let mut handle = std::ptr::null_mut::<dnnl_memory>();
let status = unsafe {
dnnl_memory_create(
&mut handle,
desc.handle,
engine.handle,
buffer.as_ptr() as *mut c_void,
)

let status = match engine.get_kind() {
Ok(dnnl_engine_kind_t::dnnl_cpu) => unsafe {
dnnl_memory_create(
&mut handle,
desc.handle,
engine.handle,
buffer.as_ptr() as *mut c_void,
)
},
Ok(dnnl_engine_kind_t::dnnl_gpu) => {
if DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL {
todo!("Add SYCL interop")
} else if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL {
todo!("Add OCL interop")
} else {
todo!("Return Error for lack of a GPU Runtime")
}
}
Ok(dnnl_engine_kind_t::dnnl_any_engine) => {
todo!("Add DNNL ANY interop")
}
Ok(_) => {
panic!("Unexpected engine kind type type")
}
Err(e) => return Err(e),
};

if status == dnnl_status_t::dnnl_success {
Expand Down

0 comments on commit 30aff54

Please sign in to comment.