Add BackwardEltwiseConfig and test for relu forward and backward
boydjohnson committed Jan 4, 2025
1 parent 65e2646 commit 4f39ed2
Showing 4 changed files with 234 additions and 11 deletions.
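In brief: the commit adds a BackwardEltwiseConfig and a BackwardEltwise operation type, makes the primitive and primitive-descriptor handles public so a forward descriptor can be passed as the backward hint, and fixes PropForwardTraining to map to dnnl_forward_training. Below is a minimal sketch of how the pieces compose; the `onednnl::...` crate paths and the helper name `relu_backward_primitive` are assumptions inferred from the test imports, and the full forward/backward ReLU flow lives in tests/test_smoke.rs further down.

use {
    onednnl::{
        engine::Engine,
        memory::descriptor::{new_plain_descriptor, DataType},
        primitive::{
            config::eltwise::{BackwardEltwiseConfig, ForwardEltwiseConfig, Unary},
            BackwardEltwise, ForwardEltwise, Primitive, PropBackward, PropForwardTraining,
        },
    },
    std::sync::Arc,
};

// Hypothetical helper: build a backward ReLU primitive for a [2, 3] f32 tensor.
fn relu_backward_primitive(engine: Arc<Engine>) -> Primitive {
    let md = new_plain_descriptor(2, vec![2, 3], DataType::F32);

    // Forward ReLU, created with the training prop kind so its descriptor can
    // serve as the hint for the backward primitive.
    let fwd = Primitive::new::<_, PropForwardTraining, ForwardEltwise<_>>(
        ForwardEltwiseConfig {
            alg_kind: Unary::RELU,
            src_desc: &md,
            dst_desc: &md,
            alpha: 0.0,
            beta: 0.0,
            attr: std::ptr::null_mut(),
        },
        engine.clone(),
    )
    .unwrap();

    // Backward ReLU built from the new BackwardEltwiseConfig, using the forward
    // primitive descriptor as the hint.
    Primitive::new::<_, PropBackward, BackwardEltwise<PropBackward>>(
        BackwardEltwiseConfig {
            alg_kind: Unary::RELU_USE_DST_FOR_BWD,
            diff_src_desc: &md,
            diff_dest_desc: &md,
            data_desc: &md,
            alpha: 0.0,
            beta: 0.0,
            forward_hint_desc: fwd.desc.handle,
            attr: std::ptr::null_mut(),
        },
        engine,
    )
    .unwrap()
}
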
15 changes: 12 additions & 3 deletions src/primitive.rs
@@ -4,7 +4,7 @@ use {
au_gru::{BackwardAuGruConfig, ForwardAuGruConfig},
batch_norm::ForwardBatchNormConfig,
binary::ForwardBinaryConfig,
eltwise::ForwardEltwiseConfig,
eltwise::{BackwardEltwiseConfig, ForwardEltwiseConfig},
matmul::ForwardMatMulConfig,
reduction::ForwardReductionConfig,
PrimitiveConfig,
@@ -98,7 +98,7 @@ impl PropType<Forward> for PropForwardInference {
}

impl PropType<Forward> for PropForwardTraining {
const KIND: dnnl_prop_kind_t::Type = dnnl_prop_kind_t::dnnl_forward_inference;
const KIND: dnnl_prop_kind_t::Type = dnnl_prop_kind_t::dnnl_forward_training;
}

impl PropType<Backward> for PropBackward {
@@ -173,8 +173,17 @@ impl<'a> Operation<'a, Forward, PropForwardInference> for ForwardReduction {
type OperationConfig = ForwardReductionConfig<'a>;
}

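/// Element-wise backward operation (e.g. ReLU backward), parameterized by the
/// backward propagation kind.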
pub struct BackwardEltwise<T: PropType<Backward>> {
pub prop_type: T,
}

impl<'a, P: PropType<Backward>> Operation<'a, Backward, P> for BackwardEltwise<P> {
const TYPE: OperationType = OperationType::Eltwise;
type OperationConfig = BackwardEltwiseConfig<'a>;
}

pub struct Primitive {
pub(crate) handle: dnnl_primitive_t,
pub handle: dnnl_primitive_t,
pub desc: PrimitiveDescriptor,
pub engine: Arc<Engine>,
}
45 changes: 43 additions & 2 deletions src/primitive/config/eltwise.rs
@@ -2,10 +2,11 @@ use {
super::PrimitiveConfig,
crate::{
memory::descriptor::MemoryDescriptor,
primitive::{descriptor::PrimitiveDescriptor, Forward, PropType},
primitive::{descriptor::PrimitiveDescriptor, Backward, Forward, PropType},
},
onednnl_sys::{
dnnl_alg_kind_t, dnnl_eltwise_forward_primitive_desc_create, dnnl_primitive_attr_t,
dnnl_alg_kind_t, dnnl_eltwise_backward_primitive_desc_create,
dnnl_eltwise_forward_primitive_desc_create, dnnl_primitive_attr_t, dnnl_primitive_desc_t,
dnnl_status_t,
},
};
@@ -86,3 +87,43 @@ impl Unary {
pub const SQUARE: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_square;
pub const SWISH: dnnl_alg_kind_t::Type = dnnl_alg_kind_t::dnnl_eltwise_swish;
}

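/// Descriptor bundle for `dnnl_eltwise_backward_primitive_desc_create`:
/// `diff_dest_desc` describes the incoming gradient, `diff_src_desc` the gradient
/// to compute w.r.t. the source, `data_desc` the forward src (or dst for the
/// *_USE_DST_FOR_BWD algorithms), and `forward_hint_desc` the primitive
/// descriptor of the matching forward pass.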
pub struct BackwardEltwiseConfig<'a> {
pub alg_kind: dnnl_alg_kind_t::Type,
pub diff_src_desc: &'a MemoryDescriptor,
pub diff_dest_desc: &'a MemoryDescriptor,
pub data_desc: &'a MemoryDescriptor,
pub alpha: f32,
pub beta: f32,
pub forward_hint_desc: dnnl_primitive_desc_t,
pub attr: dnnl_primitive_attr_t,
}

impl<'a, P: PropType<Backward>> PrimitiveConfig<'a, Backward, P> for BackwardEltwiseConfig<'a> {
fn create_primitive_desc(
&self,
engine: std::sync::Arc<crate::engine::Engine>,
) -> Result<crate::primitive::descriptor::PrimitiveDescriptor, crate::error::DnnlError> {
let mut handle = std::ptr::null_mut();
let status = unsafe {
dnnl_eltwise_backward_primitive_desc_create(
&mut handle,
engine.handle,
self.alg_kind,
self.diff_src_desc.handle,
self.diff_dest_desc.handle,
self.data_desc.handle,
self.alpha,
self.beta,
self.forward_hint_desc,
self.attr,
)
};

if status == dnnl_status_t::dnnl_success {
Ok(PrimitiveDescriptor { handle })
} else {
Err(status.into())
}
}
}
2 changes: 1 addition & 1 deletion src/primitive/descriptor.rs
@@ -6,7 +6,7 @@ use {
};

pub struct PrimitiveDescriptor {
pub(crate) handle: dnnl_primitive_desc_t,
pub handle: dnnl_primitive_desc_t,
}

impl PrimitiveDescriptor {
183 changes: 178 additions & 5 deletions tests/test_smoke.rs
@@ -4,24 +4,28 @@ use {
memory::{
buffer::AlignedBuffer,
data_type_size,
descriptor::{DataType, DataTypeQuery, DimsQuery, MemoryDescriptor, NDimsQuery},
descriptor::{
new_plain_descriptor, DataType, DataTypeQuery, DimsQuery, MemoryDescriptor,
NDimsQuery,
},
format_tag::{abc, abcd, x},
Memory,
},
primitive::{
config::{
binary::{Binary, ForwardBinaryConfig},
eltwise::{BackwardEltwiseConfig, ForwardEltwiseConfig, Unary},
matmul::ForwardMatMulConfig,
reduction::{ForwardReductionConfig, Reduction},
},
ExecArg, ForwardBinary, ForwardMatMul, ForwardReduction, Primitive,
PropForwardInference,
BackwardEltwise, ExecArg, ForwardBinary, ForwardEltwise, ForwardMatMul,
ForwardReduction, Primitive, PropBackward, PropForwardInference, PropForwardTraining,
},
stream::Stream,
},
onednnl_sys::{
dnnl_data_type_t::dnnl_f32, DNNL_ARG_BIAS, DNNL_ARG_DST, DNNL_ARG_SRC, DNNL_ARG_SRC_0,
DNNL_ARG_SRC_1, DNNL_ARG_WEIGHTS,
dnnl_data_type_t::dnnl_f32, DNNL_ARG_BIAS, DNNL_ARG_DIFF_DST, DNNL_ARG_DIFF_SRC,
DNNL_ARG_DST, DNNL_ARG_SRC, DNNL_ARG_SRC_0, DNNL_ARG_SRC_1, DNNL_ARG_WEIGHTS,
},
std::sync::Arc,
};
@@ -284,3 +288,172 @@ pub fn test_reduction_smoke() {

assert_eq!(dst_memory.to_vec(), Ok(vec![6.0]));
}

#[test]
fn test_relu_forward_backward() {
// 1. Create an engine (CPU in this example)
let engine = Engine::new(Engine::CPU, 0).unwrap();

// ---------------------------------------------------
// 2. Prepare input data (shape = [2, 3])
// We'll intentionally include negative values to test ReLU clamping at 0.
let src_data: Vec<f32> = vec![-1.0f32, 2.0, -3.0, 4.0, 0.0, 5.0];
let dims = [2, 3];

// 2a. Create a memory descriptor for src
let src_md = new_plain_descriptor(2, dims.to_vec(), DataType::F32);

let dst_md = new_plain_descriptor(2, dims.to_vec(), DataType::F32);

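// 3a. Configure the forward ReLU eltwise operation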
let forward_config = ForwardEltwiseConfig {
alg_kind: Unary::RELU, // ReLU forward
src_desc: &src_md,
dst_desc: &dst_md,
alpha: 0.0,
beta: 0.0,
attr: std::ptr::null_mut(), // no special attributes
};

// 3b. Create the forward primitive
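//     (use PropForwardTraining so its descriptor can later serve as the backward hint)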
let fwd_prim =
Primitive::new::<_, PropForwardTraining, ForwardEltwise<_>>(forward_config, engine.clone())
.unwrap();

// 3c. Allocate memory for the forward result

let a_buffer =
AlignedBuffer::zeroed(dst_md.get_size() / data_type_size(DataType::F32)).unwrap();

let dst_mem = Memory::new_with_user_buffer(engine.clone(), dst_md, a_buffer).unwrap();

let buffer = AlignedBuffer::new(&src_data).unwrap();

let src_mem = Memory::new_with_user_buffer(engine.clone(), src_md, buffer).unwrap();

let stream = Stream::new(engine.clone()).unwrap();

// 3d. Execute forward ReLU
fwd_prim
.execute(
&stream,
vec![
ExecArg {
index: DNNL_ARG_SRC as i32,
mem: &src_mem,
},
ExecArg {
index: DNNL_ARG_DST as i32,
mem: &dst_mem,
},
],
)
.unwrap();

stream.wait().unwrap();

// ---------------------------------------------------
// 4. Validate Forward Output
//
// ReLU(x) = max(0, x). So for:
// -1.0 -> 0.0
// 2.0 -> 2.0
// -3.0 -> 0.0
// 4.0 -> 4.0
// 0.0 -> 0.0
// 5.0 -> 5.0
let forward_result = dst_mem.to_vec().unwrap();

let expected_forward: Vec<f32> = vec![0.0, 2.0, 0.0, 4.0, 0.0, 5.0];
assert_eq!(
forward_result, expected_forward,
"Forward ReLU output mismatch"
);

// ---------------------------------------------------
// 5. Backward ReLU Configuration
//
// We'll define "diff_dst" as if the gradient from the next layer is all 1.0:
// shape = [2, 3], so all ones => [1,1,1,1,1,1].
let diff_dst_data = AlignedBuffer::new(&vec![1.0; src_data.len()]).unwrap();
let diff_dst_md = src_mem.desc.clone_desc().unwrap();

// We'll store the result of the backward pass (the gradient w.r.t src) in diff_src
let diff_src_md = src_mem.desc.clone_desc().unwrap();

// We also need a "forward hint descriptor" from the forward pass.
let forward_hint_desc = fwd_prim.desc.handle; // The C-level primitive_desc handle

let bwd_config = BackwardEltwiseConfig {
alg_kind: Unary::RELU_USE_DST_FOR_BWD,
diff_src_desc: &diff_src_md,
diff_dest_desc: &diff_dst_md,
data_desc: &dst_mem.desc, // "data_desc" is the forward src, or the forward dst for *_USE_DST_FOR_BWD algorithms
alpha: 0.0,
beta: 0.0,
forward_hint_desc,
attr: std::ptr::null_mut(),
};

// 5b. Create the backward primitive
let bwd_prim = Primitive::new::<_, PropBackward, BackwardEltwise<PropBackward>>(
bwd_config,
engine.clone(),
)
.unwrap();

let diff_dst_mem =
Memory::new_with_user_buffer(engine.clone(), diff_dst_md, diff_dst_data).unwrap();

let a_buffer =
AlignedBuffer::zeroed(diff_src_md.get_size() / data_type_size(DataType::F32)).unwrap();

let diff_src_mem = Memory::new_with_user_buffer(engine.clone(), diff_src_md, a_buffer).unwrap();

// 5c. Execute backward ReLU
//
// We'll pass:
// - "diff_dst" as input (DNNL_ARG_DIFF_DST),
// - "dst" from forward pass if using *_USE_DST_FOR_BWD variant
// - "diff_src" as output.
bwd_prim
.execute(
&stream,
vec![
// "diff_dst" as input gradient
ExecArg {
index: DNNL_ARG_DIFF_DST as i32,
mem: &diff_dst_mem,
},
// "dst" from the forward pass if using *_USE_DST_FOR_BWD
ExecArg {
index: DNNL_ARG_DST as i32,
mem: &dst_mem,
},
// "diff_src" as output gradient (w.r.t. src)
ExecArg {
index: DNNL_ARG_DIFF_SRC as i32,
mem: &diff_src_mem,
},
],
)
.unwrap();

stream.wait().unwrap();

// ---------------------------------------------------
// 6. Validate Backward Gradient
//
// The rule for ReLU backward:
// dX = dY if Y > 0, else 0 (for standard ReLU).
//
// Our forward output was [0,2,0,4,0,5].
// The "diff_dst" is all 1.0 => [1,1,1,1,1,1].
// => diff_src = [0,1,0,1,0,1].
let backward_result = diff_src_mem.to_vec().unwrap();

let expected_backward = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0];
assert_eq!(
backward_result, expected_backward,
"Backward ReLU gradient mismatch"
);
}
