Skip to content

Commit

Permalink
Add ForwardPreluConfig
Browse files Browse the repository at this point in the history
  • Loading branch information
boydjohnson committed Jan 10, 2025
1 parent 3abdb31 commit 2597863
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 2 deletions.
4 changes: 2 additions & 2 deletions TODOS.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
| `lbr_gru` |||||
| `lrn` |||||
| `lstm` |||||
| `matmul` || |||
| `matmul` || |||
| `pooling` |||||
| `prelu` | ||||
| `prelu` | ||||
| `reduction` |||||
| `reorder` |||||
| `resampling` |||||
Expand Down
1 change: 1 addition & 0 deletions src/primitive/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ pub mod binary;
pub mod eltwise;
pub mod inner_product;
pub mod matmul;
pub mod prelu;
pub mod reduction;

pub trait PrimitiveConfig<'a, D: Direction, P: PropType<D>> {
Expand Down
53 changes: 53 additions & 0 deletions src/primitive/config/prelu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use {
super::PrimitiveConfig,
crate::{
memory::descriptor::MemoryDescriptor,
onednnl_sys::{dnnl_prelu_forward_primitive_desc_create, dnnl_status_t},
primitive::{
attributes::PrimitiveAttributes, descriptor::PrimitiveDescriptor, Forward, Operation,
OperationType, PropType,
},
},
};

pub struct ForwardPreluConfig<'a> {
pub src_desc: &'a MemoryDescriptor,
weights_desc: &'a MemoryDescriptor,
dst_desc: &'a MemoryDescriptor,
attr: &'a PrimitiveAttributes,
}

impl<'a, P: PropType<Forward>> PrimitiveConfig<'a, Forward, P> for ForwardPreluConfig<'a> {
fn create_primitive_desc(
&self,
engine: std::sync::Arc<crate::engine::Engine>,
) -> Result<crate::primitive::descriptor::PrimitiveDescriptor, crate::error::DnnlError> {
let mut handle = std::ptr::null_mut();

let status = unsafe {
dnnl_prelu_forward_primitive_desc_create(
&mut handle,
engine.handle,
P::KIND,
self.src_desc.handle,
self.weights_desc.handle,
self.dst_desc.handle,
self.attr.handle,
)
};
if status == dnnl_status_t::dnnl_success {
Ok(PrimitiveDescriptor { handle })
} else {
Err(status.into())
}
}
}

pub struct ForwardPrelu<P: PropType<Forward>> {
pub prop_type: P,
}

impl<'a, P: PropType<Forward>> Operation<'a, Forward, P> for ForwardPrelu<P> {
const TYPE: crate::primitive::OperationType = OperationType::PRelu;
type OperationConfig = ForwardPreluConfig<'a>;
}

0 comments on commit 2597863

Please sign in to comment.