From f5ae543a8e0d6a1c7a80708cbbcf3968624b1993 Mon Sep 17 00:00:00 2001 From: Boyd Johnson Date: Fri, 10 Jan 2025 23:15:11 +0000 Subject: [PATCH 1/2] Add primitives mod, that is one directory closer to src --- benches/binary.rs | 7 ++---- src/lib.rs | 1 + src/primitive.rs | 5 ++--- src/primitive/config.rs | 9 -------- src/primitive/descriptor.rs | 7 +++--- src/primitives.rs | 8 +++++++ .../config => primitives}/au_gru.rs | 5 ++--- .../config => primitives}/batch_norm.rs | 6 ++--- .../config => primitives}/binary.rs | 6 ++--- .../config => primitives}/eltwise.rs | 5 ++--- .../config => primitives}/inner_product.rs | 6 ++--- .../config => primitives}/matmul.rs | 5 ++--- src/{primitive/config => primitives}/prelu.rs | 15 +++++-------- .../config => primitives}/reduction.rs | 6 ++--- tests/test_smoke.rs | 22 +++++++++---------- 15 files changed, 49 insertions(+), 64 deletions(-) create mode 100644 src/primitives.rs rename src/{primitive/config => primitives}/au_gru.rs (96%) rename src/{primitive/config => primitives}/batch_norm.rs (90%) rename src/{primitive/config => primitives}/binary.rs (92%) rename src/{primitive/config => primitives}/eltwise.rs (97%) rename src/{primitive/config => primitives}/inner_product.rs (95%) rename src/{primitive/config => primitives}/matmul.rs (90%) rename src/{primitive/config => primitives}/prelu.rs (79%) rename src/{primitive/config => primitives}/reduction.rs (92%) diff --git a/benches/binary.rs b/benches/binary.rs index 7421520..c0f6949 100644 --- a/benches/binary.rs +++ b/benches/binary.rs @@ -12,11 +12,8 @@ use { format_tag::x, Memory, }, - primitive::{ - attributes::PrimitiveAttributes, - config::binary::{Binary, ForwardBinary, ForwardBinaryConfig}, - ExecArg, Primitive, PropForwardInference, - }, + primitive::{attributes::PrimitiveAttributes, ExecArg, Primitive, PropForwardInference}, + primitives::binary::{Binary, ForwardBinary, ForwardBinaryConfig}, set_primitive_cache_capacity, stream::Stream, }, diff --git a/src/lib.rs b/src/lib.rs index d6f722f..1bc7cef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ pub mod engine; pub mod error; pub mod memory; pub mod primitive; +pub mod primitives; pub mod stream; use error::DnnlError; diff --git a/src/primitive.rs b/src/primitive.rs index 2c89246..e3b6e97 100644 --- a/src/primitive.rs +++ b/src/primitive.rs @@ -125,10 +125,9 @@ impl Primitive { /// engine::Engine, /// memory::{descriptor::MemoryDescriptor, format_tag::x}, /// primitive::{ - /// attributes::PrimitiveAttributes, - /// config::binary::{ForwardBinary, ForwardBinaryConfig}, - /// Forward, Primitive, PropForwardInference, + /// attributes::PrimitiveAttributes, Forward, Primitive, PropForwardInference, /// }, + /// primitives::binary::{ForwardBinary, ForwardBinaryConfig}, /// }, /// onednnl_sys::{dnnl_alg_kind_t, dnnl_data_type_t::dnnl_f32}, /// }; diff --git a/src/primitive/config.rs b/src/primitive/config.rs index b132876..ccb4750 100644 --- a/src/primitive/config.rs +++ b/src/primitive/config.rs @@ -4,15 +4,6 @@ use { std::sync::Arc, }; -pub mod au_gru; -pub mod batch_norm; -pub mod binary; -pub mod eltwise; -pub mod inner_product; -pub mod matmul; -pub mod prelu; -pub mod reduction; - pub trait PrimitiveConfig<'a, D: Direction, P: PropType> { fn create_primitive_desc(&self, engine: Arc) -> Result; } diff --git a/src/primitive/descriptor.rs b/src/primitive/descriptor.rs index 59c3e75..e4cbf52 100644 --- a/src/primitive/descriptor.rs +++ b/src/primitive/descriptor.rs @@ -20,11 +20,10 @@ impl PrimitiveDescriptor { /// engine::Engine, /// 
memory::{descriptor::MemoryDescriptor, format_tag::x}, /// primitive::{ - /// attributes::PrimitiveAttributes, - /// config::binary::{ForwardBinary, ForwardBinaryConfig}, - /// descriptor::PrimitiveDescriptor, - /// Forward, PropForwardInference, + /// attributes::PrimitiveAttributes, descriptor::PrimitiveDescriptor, Forward, + /// PropForwardInference, /// }, + /// primitives::binary::{ForwardBinary, ForwardBinaryConfig}, /// }, /// onednnl_sys::{dnnl_alg_kind_t, dnnl_data_type_t::dnnl_f32}, /// }; diff --git a/src/primitives.rs b/src/primitives.rs new file mode 100644 index 0000000..5dd9ec4 --- /dev/null +++ b/src/primitives.rs @@ -0,0 +1,8 @@ +pub mod au_gru; +pub mod batch_norm; +pub mod binary; +pub mod eltwise; +pub mod inner_product; +pub mod matmul; +pub mod prelu; +pub mod reduction; diff --git a/src/primitive/config/au_gru.rs b/src/primitives/au_gru.rs similarity index 96% rename from src/primitive/config/au_gru.rs rename to src/primitives/au_gru.rs index a02c764..98938b3 100644 --- a/src/primitive/config/au_gru.rs +++ b/src/primitives/au_gru.rs @@ -1,12 +1,11 @@ use { - super::PrimitiveConfig, crate::{ engine::Engine, error::DnnlError, memory::descriptor::MemoryDescriptor, primitive::{ - attributes::PrimitiveAttributes, descriptor::PrimitiveDescriptor, Backward, Forward, - Operation, OperationType, PropType, + attributes::PrimitiveAttributes, config::PrimitiveConfig, + descriptor::PrimitiveDescriptor, Backward, Forward, Operation, OperationType, PropType, }, }, onednnl_sys::{ diff --git a/src/primitive/config/batch_norm.rs b/src/primitives/batch_norm.rs similarity index 90% rename from src/primitive/config/batch_norm.rs rename to src/primitives/batch_norm.rs index 9a793f8..4d15798 100644 --- a/src/primitive/config/batch_norm.rs +++ b/src/primitives/batch_norm.rs @@ -2,16 +2,14 @@ use { crate::{ memory::descriptor::MemoryDescriptor, primitive::{ - attributes::PrimitiveAttributes, descriptor::PrimitiveDescriptor, Forward, Operation, - OperationType, PropType, + attributes::PrimitiveAttributes, config::PrimitiveConfig, + descriptor::PrimitiveDescriptor, Forward, Operation, OperationType, PropType, }, }, onednnl_sys::{dnnl_batch_normalization_forward_primitive_desc_create, dnnl_status_t}, std::ffi::c_uint, }; -use super::PrimitiveConfig; - pub struct ForwardBatchNormConfig<'a> { src_desc: &'a MemoryDescriptor, dst_desc: &'a MemoryDescriptor, diff --git a/src/primitive/config/binary.rs b/src/primitives/binary.rs similarity index 92% rename from src/primitive/config/binary.rs rename to src/primitives/binary.rs index 739414d..060770c 100644 --- a/src/primitive/config/binary.rs +++ b/src/primitives/binary.rs @@ -1,10 +1,10 @@ use { - super::PrimitiveConfig, crate::{ memory::descriptor::MemoryDescriptor, primitive::{ - attributes::PrimitiveAttributes, descriptor::PrimitiveDescriptor, Forward, Operation, - OperationType, PropForwardInference, PropType, + attributes::PrimitiveAttributes, config::PrimitiveConfig, + descriptor::PrimitiveDescriptor, Forward, Operation, OperationType, + PropForwardInference, PropType, }, }, onednnl_sys::{dnnl_alg_kind_t, dnnl_binary_primitive_desc_create, dnnl_status_t}, diff --git a/src/primitive/config/eltwise.rs b/src/primitives/eltwise.rs similarity index 97% rename from src/primitive/config/eltwise.rs rename to src/primitives/eltwise.rs index 7edd791..e6ca73b 100644 --- a/src/primitive/config/eltwise.rs +++ b/src/primitives/eltwise.rs @@ -1,10 +1,9 @@ use { - super::PrimitiveConfig, crate::{ memory::descriptor::MemoryDescriptor, primitive::{ - 
attributes::PrimitiveAttributes, descriptor::PrimitiveDescriptor, Backward, Forward, - Operation, OperationType, PropType, + attributes::PrimitiveAttributes, config::PrimitiveConfig, + descriptor::PrimitiveDescriptor, Backward, Forward, Operation, OperationType, PropType, }, }, onednnl_sys::{ diff --git a/src/primitive/config/inner_product.rs b/src/primitives/inner_product.rs similarity index 95% rename from src/primitive/config/inner_product.rs rename to src/primitives/inner_product.rs index f8797a4..cd74f65 100644 --- a/src/primitive/config/inner_product.rs +++ b/src/primitives/inner_product.rs @@ -1,10 +1,10 @@ use { - super::PrimitiveConfig, crate::{ memory::descriptor::MemoryDescriptor, primitive::{ - attributes::PrimitiveAttributes, descriptor::PrimitiveDescriptor, Backward, Forward, - Operation, OperationType, PropBackwardData, PropBackwardWeights, PropType, + attributes::PrimitiveAttributes, config::PrimitiveConfig, + descriptor::PrimitiveDescriptor, Backward, Forward, Operation, OperationType, + PropBackwardData, PropBackwardWeights, PropType, }, }, onednnl_sys::{ diff --git a/src/primitive/config/matmul.rs b/src/primitives/matmul.rs similarity index 90% rename from src/primitive/config/matmul.rs rename to src/primitives/matmul.rs index 15e822d..75010ec 100644 --- a/src/primitive/config/matmul.rs +++ b/src/primitives/matmul.rs @@ -1,10 +1,9 @@ use { - super::PrimitiveConfig, crate::{ memory::descriptor::MemoryDescriptor, primitive::{ - attributes::PrimitiveAttributes, descriptor::PrimitiveDescriptor, Forward, Operation, - OperationType, PropType, + attributes::PrimitiveAttributes, config::PrimitiveConfig, + descriptor::PrimitiveDescriptor, Forward, Operation, OperationType, PropType, }, }, onednnl_sys::{dnnl_matmul_primitive_desc_create, dnnl_status_t}, diff --git a/src/primitive/config/prelu.rs b/src/primitives/prelu.rs similarity index 79% rename from src/primitive/config/prelu.rs rename to src/primitives/prelu.rs index 3bd0155..331f1d5 100644 --- a/src/primitive/config/prelu.rs +++ b/src/primitives/prelu.rs @@ -1,12 +1,9 @@ -use { - super::PrimitiveConfig, - crate::{ - memory::descriptor::MemoryDescriptor, - onednnl_sys::{dnnl_prelu_forward_primitive_desc_create, dnnl_status_t}, - primitive::{ - attributes::PrimitiveAttributes, descriptor::PrimitiveDescriptor, Forward, Operation, - OperationType, PropType, - }, +use crate::{ + memory::descriptor::MemoryDescriptor, + onednnl_sys::{dnnl_prelu_forward_primitive_desc_create, dnnl_status_t}, + primitive::{ + attributes::PrimitiveAttributes, config::PrimitiveConfig, descriptor::PrimitiveDescriptor, + Forward, Operation, OperationType, PropType, }, }; diff --git a/src/primitive/config/reduction.rs b/src/primitives/reduction.rs similarity index 92% rename from src/primitive/config/reduction.rs rename to src/primitives/reduction.rs index 557dbfc..96704f3 100644 --- a/src/primitive/config/reduction.rs +++ b/src/primitives/reduction.rs @@ -1,10 +1,10 @@ use { - super::PrimitiveConfig, crate::{ memory::descriptor::MemoryDescriptor, primitive::{ - attributes::PrimitiveAttributes, descriptor::PrimitiveDescriptor, Forward, Operation, - OperationType, PropForwardInference, + attributes::PrimitiveAttributes, config::PrimitiveConfig, + descriptor::PrimitiveDescriptor, Forward, Operation, OperationType, + PropForwardInference, }, }, onednnl_sys::{dnnl_alg_kind_t, dnnl_reduction_primitive_desc_create, dnnl_status_t}, diff --git a/tests/test_smoke.rs b/tests/test_smoke.rs index 051f801..45d88cb 100644 --- a/tests/test_smoke.rs +++ 
b/tests/test_smoke.rs @@ -12,18 +12,16 @@ use { Memory, }, primitive::{ - attributes::PrimitiveAttributes, - config::{ - binary::{Binary, ForwardBinary, ForwardBinaryConfig}, - eltwise::{ - BackwardEltwise, BackwardEltwiseConfig, ForwardEltwise, ForwardEltwiseConfig, - Unary, - }, - matmul::{ForwardMatMul, ForwardMatMulConfig}, - reduction::{ForwardReduction, ForwardReductionConfig, Reduction}, + attributes::PrimitiveAttributes, ExecArg, Primitive, PropBackward, PropBackwardData, + PropBackwardWeights, PropForwardInference, PropForwardTraining, + }, + primitives::{ + binary::{Binary, ForwardBinary, ForwardBinaryConfig}, + eltwise::{ + BackwardEltwise, BackwardEltwiseConfig, ForwardEltwise, ForwardEltwiseConfig, Unary, }, - ExecArg, Primitive, PropBackward, PropBackwardData, PropBackwardWeights, - PropForwardInference, PropForwardTraining, + matmul::{ForwardMatMul, ForwardMatMulConfig}, + reduction::{ForwardReduction, ForwardReductionConfig, Reduction}, }, stream::Stream, }, @@ -465,7 +463,7 @@ fn test_inner_product_nchw_to_nc_backprop() { DNNL_ARG_BIAS, DNNL_ARG_DIFF_BIAS, DNNL_ARG_DIFF_DST, DNNL_ARG_DIFF_SRC, DNNL_ARG_DIFF_WEIGHTS, DNNL_ARG_DST, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, }, - primitive::config::inner_product::{ + primitives::inner_product::{ BackwardDataInnerProduct, BackwardDataInnerProductConfig, BackwardWeightsInnerProduct, BackwardWeightsInnerProductConfig, ForwardInnerProduct, ForwardInnerProductConfig, }, From bb116237352503131f282108efea0f7881bf7e51 Mon Sep 17 00:00:00 2001 From: Boyd Johnson Date: Sat, 11 Jan 2025 15:40:53 +0000 Subject: [PATCH 2/2] Break tests out of one file --- tests/test_inner_product.rs | 282 ++++++++++++++++++++++ tests/test_relu.rs | 186 +++++++++++++++ tests/test_smoke.rs | 451 +----------------------------------- 3 files changed, 471 insertions(+), 448 deletions(-) create mode 100644 tests/test_inner_product.rs create mode 100644 tests/test_relu.rs diff --git a/tests/test_inner_product.rs b/tests/test_inner_product.rs new file mode 100644 index 0000000..9ae08b2 --- /dev/null +++ b/tests/test_inner_product.rs @@ -0,0 +1,282 @@ +#[test] +fn test_inner_product_nchw_to_nc_backprop() { + use onednnl::{ + engine::Engine, + memory::{ + buffer::AlignedBuffer, + descriptor::{new_plain_descriptor, DataType}, + Memory, + }, + onednnl_sys::{ + DNNL_ARG_BIAS, DNNL_ARG_DIFF_BIAS, DNNL_ARG_DIFF_DST, DNNL_ARG_DIFF_SRC, + DNNL_ARG_DIFF_WEIGHTS, DNNL_ARG_DST, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, + }, + primitive::{ + attributes::PrimitiveAttributes, ExecArg, Primitive, PropBackwardData, + PropBackwardWeights, PropForwardTraining, + }, + primitives::inner_product::{ + BackwardDataInnerProduct, BackwardDataInnerProductConfig, BackwardWeightsInnerProduct, + BackwardWeightsInnerProductConfig, ForwardInnerProduct, ForwardInnerProductConfig, + }, + stream::Stream, + }; + + // 1. Create an engine (CPU) + let engine = Engine::new(Engine::CPU, 0).unwrap(); + let stream = Stream::new(engine.clone()).unwrap(); + + // --------------------------------------------------- + // 2. 
Prepare input shapes/dimensions
+    //
+    //    Modeled on the C++ example (which uses N = 3), but here:
+    //    N = 15, IC = 3, IH = 227, IW = 227, OC = 96
+    //    We'll do an inner product from [N, IC, IH, IW] => [N, OC]
+    //    Weights = [OC, IC, IH, IW]
+    //    Bias   = [OC]
+    //
+    let n: i64 = 15;
+    let ic: i64 = 3;
+    let ih: i64 = 227;
+    let iw: i64 = 227;
+    let oc: i64 = 96;
+
+    let src_dims = [n, ic, ih, iw]; // shape [15, 3, 227, 227]
+    let weights_dims = [oc, ic, ih, iw]; // shape [96, 3, 227, 227]
+    let bias_dims = [oc]; // shape [96]
+    let dst_dims = [n, oc]; // shape [15, 96]
+
+    // 2a. Create memory descriptors (plain / row-major)
+    let src_md = new_plain_descriptor(4, src_dims.to_vec(), DataType::F32);
+    let weights_md = new_plain_descriptor(4, weights_dims.to_vec(), DataType::F32);
+    let bias_md = new_plain_descriptor(1, bias_dims.to_vec(), DataType::F32);
+    let dst_md = new_plain_descriptor(2, dst_dims.to_vec(), DataType::F32);
+
+    // 2b. Allocate some input data (all constant, just for demonstration).
+    let src_len = (n * ic * ih * iw) as usize;
+    let weights_len = (oc * ic * ih * iw) as usize;
+    let bias_len = oc as usize;
+    let dst_len = (n * oc) as usize;
+
+    let src_data = vec![0.5_f32; src_len];
+    let weights_data = vec![0.1_f32; weights_len];
+    let bias_data = vec![0.0_f32; bias_len];
+    let dst_data = vec![0.0_f32; dst_len]; // Will hold forward output
+
+    // Wrap them in user buffers
+    let src_buf = AlignedBuffer::new(&src_data).unwrap();
+    let src_mem =
+        Memory::new_with_user_buffer(engine.clone(), src_md.clone_desc().unwrap(), src_buf)
+            .unwrap();
+
+    let weights_buf = AlignedBuffer::new(&weights_data).unwrap();
+    let weights_mem = Memory::new_with_user_buffer(
+        engine.clone(),
+        weights_md.clone_desc().unwrap(),
+        weights_buf,
+    )
+    .unwrap();
+
+    let bias_buf = AlignedBuffer::new(&bias_data).unwrap();
+    let bias_mem =
+        Memory::new_with_user_buffer(engine.clone(), bias_md.clone_desc().unwrap(), bias_buf)
+            .unwrap();
+
+    let dst_buf = AlignedBuffer::new(&dst_data).unwrap();
+    let dst_mem =
+        Memory::new_with_user_buffer(engine.clone(), dst_md.clone_desc().unwrap(), dst_buf)
+            .unwrap();
+
+    // ---------------------------------------------------
+    // 3. Forward Inner Product
+    let fwd_config = ForwardInnerProductConfig {
+        src_desc: &src_md,
+        weights_desc: &weights_md,
+        bias_desc: &bias_md,
+        dst_desc: &dst_md,
+        attr: &PrimitiveAttributes::new().unwrap(),
+    };
+
+    // 3a. Create the forward primitive
+    let fwd_prim = Primitive::new::<_, PropForwardTraining, ForwardInnerProduct<_>>(
+        fwd_config,
+        engine.clone(),
+    )
+    .unwrap();
+
+    // 3b. Execute forward
+    fwd_prim
+        .execute(
+            &stream,
+            vec![
+                ExecArg {
+                    index: DNNL_ARG_SRC as i32,
+                    mem: &src_mem,
+                },
+                ExecArg {
+                    index: DNNL_ARG_WEIGHTS as i32,
+                    mem: &weights_mem,
+                },
+                ExecArg {
+                    index: DNNL_ARG_BIAS as i32,
+                    mem: &bias_mem,
+                },
+                ExecArg {
+                    index: DNNL_ARG_DST as i32,
+                    mem: &dst_mem,
+                },
+            ],
+        )
+        .unwrap();
+    stream.wait().unwrap();
+
+    // 3c. Print a few forward outputs
+    let forward_result = dst_mem.to_vec().unwrap();
+    println!("\n== Forward Pass ==");
+    println!("Forward output shape = [{}, {}]", n, oc);
+    println!(
+        "First few elements: {:?}",
+        &forward_result[..8.min(forward_result.len())]
+    );
+
+    // ---------------------------------------------------
+    // 4. Backward Weights: compute gradient w.r.t. weights and bias
+    //
+    //    We'll define diff_dst as shape = [N, OC], typically the gradient
+    //    from the next layer. For demonstration, fill with 1.0.
+ let diff_dst_data = vec![1.0_f32; dst_len]; + let diff_dst_buf = AlignedBuffer::new(&diff_dst_data).unwrap(); + let diff_dst_mem = + Memory::new_with_user_buffer(engine.clone(), dst_md.clone_desc().unwrap(), diff_dst_buf) + .unwrap(); + + // We'll store diff_weights in a new user buffer + let diff_weights_buf = AlignedBuffer::zeroed(weights_len).unwrap(); + let diff_weights_mem = Memory::new_with_user_buffer( + engine.clone(), + weights_md.clone_desc().unwrap(), + diff_weights_buf, + ) + .unwrap(); + + // We'll store diff_bias in a new user buffer + let diff_bias_buf = AlignedBuffer::zeroed(bias_len).unwrap(); + let diff_bias_mem = + Memory::new_with_user_buffer(engine.clone(), bias_md.clone_desc().unwrap(), diff_bias_buf) + .unwrap(); + + let bwd_weights_config = BackwardWeightsInnerProductConfig { + src_desc: &src_md, + diff_weights_desc: &weights_md, + diff_bias_desc: &bias_md, + diff_dst_desc: &dst_md, + hint_fwd_pd: &fwd_prim.desc, // from the forward primitive + attr: &PrimitiveAttributes::new().unwrap(), + }; + + // 4a. Create backward-weights primitive + let bwd_weights_prim = Primitive::new::<_, PropBackwardWeights, BackwardWeightsInnerProduct>( + bwd_weights_config, + engine.clone(), + ) + .unwrap(); + + // 4b. Execute backward-weights + bwd_weights_prim + .execute( + &stream, + vec![ + ExecArg { + index: DNNL_ARG_SRC as i32, + mem: &src_mem, + }, + ExecArg { + index: DNNL_ARG_DIFF_DST as i32, + mem: &diff_dst_mem, + }, + ExecArg { + index: DNNL_ARG_DIFF_WEIGHTS as i32, + mem: &diff_weights_mem, + }, + ExecArg { + index: DNNL_ARG_DIFF_BIAS as i32, + mem: &diff_bias_mem, + }, + ], + ) + .unwrap(); + stream.wait().unwrap(); + + // 4c. Print a few backward-weights outputs + let diff_weights_result = diff_weights_mem.to_vec().unwrap(); + let diff_bias_result = diff_bias_mem.to_vec().unwrap(); + println!("\n== Backward Weights =="); + println!( + "diff_weights: First few elements = {:?}", + &diff_weights_result[..8.min(diff_weights_result.len())] + ); + println!( + "diff_bias: First few elements = {:?}", + &diff_bias_result[..8.min(diff_bias_result.len())] + ); + + // --------------------------------------------------- + // 5. Backward Data: compute gradient w.r.t. src + // + // We'll produce diff_src from: + // - diff_dst + the original weights. + // The shape is the same as src_dims: [N, IC, IH, IW]. + let diff_src_buf = AlignedBuffer::zeroed(src_len).unwrap(); + let diff_src_mem = + Memory::new_with_user_buffer(engine.clone(), src_md.clone_desc().unwrap(), diff_src_buf) + .unwrap(); + + let bwd_data_config = BackwardDataInnerProductConfig { + diff_src_desc: &src_md, + weights_desc: &weights_md, + diff_dst_desc: &dst_md, + hint_fwd_pd: &fwd_prim.desc, // from forward pass + attr: &PrimitiveAttributes::new().unwrap(), + }; + + // 5a. Create backward-data primitive + let bwd_data_prim = Primitive::new::<_, PropBackwardData, BackwardDataInnerProduct>( + bwd_data_config, + engine.clone(), + ) + .unwrap(); + + // 5b. Execute backward-data + bwd_data_prim + .execute( + &stream, + vec![ + ExecArg { + index: DNNL_ARG_DIFF_DST as i32, + mem: &diff_dst_mem, + }, + ExecArg { + index: DNNL_ARG_WEIGHTS as i32, + mem: &weights_mem, + }, + ExecArg { + index: DNNL_ARG_DIFF_SRC as i32, + mem: &diff_src_mem, + }, + ], + ) + .unwrap(); + stream.wait().unwrap(); + + // 5c. 
Print a few backward-data outputs + let diff_src_result = diff_src_mem.to_vec().unwrap(); + println!("\n== Backward Data =="); + println!( + "diff_src shape = [N, IC, IH, IW] = [{}, {}, {}, {}]", + n, ic, ih, iw + ); + println!( + "diff_src: First few elements = {:?}", + &diff_src_result[..8.min(diff_src_result.len())] + ); +} diff --git a/tests/test_relu.rs b/tests/test_relu.rs new file mode 100644 index 0000000..1cbe659 --- /dev/null +++ b/tests/test_relu.rs @@ -0,0 +1,186 @@ +use onednnl::{ + engine::Engine, + memory::{ + buffer::AlignedBuffer, + data_type_size, + descriptor::{new_plain_descriptor, DataType}, + Memory, + }, + onednnl_sys::{DNNL_ARG_DIFF_DST, DNNL_ARG_DIFF_SRC, DNNL_ARG_DST, DNNL_ARG_SRC}, + primitive::{ + attributes::PrimitiveAttributes, ExecArg, Primitive, PropBackward, PropForwardTraining, + }, + primitives::eltwise::{ + BackwardEltwise, BackwardEltwiseConfig, ForwardEltwise, ForwardEltwiseConfig, Unary, + }, + stream::Stream, +}; + +#[test] +fn test_relu_forward_backward() { + // 1. Create an engine (CPU in this example) + let engine = Engine::new(Engine::CPU, 0).unwrap(); + + // --------------------------------------------------- + // 2. Prepare input data (shape = [2, 3]) + // We'll intentionally include negative values to test ReLU clamping at 0. + let src_data: Vec = vec![-1.0f32, 2.0, -3.0, 4.0, 0.0, 5.0]; + let dims = [2, 3]; + + // 2a. Create a memory descriptor for src + let src_md = new_plain_descriptor(2, dims.to_vec(), DataType::F32); + + let dst_md = new_plain_descriptor(2, dims.to_vec(), DataType::F32); + + let forward_config = ForwardEltwiseConfig { + alg_kind: Unary::RELU, // ReLU forward + src_desc: &src_md, + dst_desc: &dst_md, + alpha: 0.0, + beta: 0.0, + attr: &PrimitiveAttributes::new().unwrap(), // no special attributes + }; + + // 3b. Create the forward primitive + let fwd_prim = + Primitive::new::<_, PropForwardTraining, ForwardEltwise<_>>(forward_config, engine.clone()) + .unwrap(); + + // 3c. Allocate memory for the forward result + + let a_buffer = + AlignedBuffer::zeroed(dst_md.get_size() / data_type_size(DataType::F32)).unwrap(); + + let dst_mem = Memory::new_with_user_buffer(engine.clone(), dst_md, a_buffer).unwrap(); + + let buffer = AlignedBuffer::new(&src_data).unwrap(); + + let src_mem = Memory::new_with_user_buffer(engine.clone(), src_md, buffer).unwrap(); + + let stream = Stream::new(engine.clone()).unwrap(); + + // 3d. Execute forward ReLU + fwd_prim + .execute( + &stream, + vec![ + ExecArg { + index: DNNL_ARG_SRC as i32, + mem: &src_mem, + }, + ExecArg { + index: DNNL_ARG_DST as i32, + mem: &dst_mem, + }, + ], + ) + .unwrap(); + + stream.wait().unwrap(); + + // --------------------------------------------------- + // 4. Validate Forward Output + // + // ReLU(x) = max(0, x). So for: + // -1.0 -> 0.0 + // 2.0 -> 2.0 + // -3.0 -> 0.0 + // 4.0 -> 4.0 + // 0.0 -> 0.0 + // 5.0 -> 5.0 + let forward_result = dst_mem.to_vec().unwrap(); + + let expected_forward: Vec = vec![0.0, 2.0, 0.0, 4.0, 0.0, 5.0]; + assert_eq!( + forward_result, expected_forward, + "Forward ReLU output mismatch" + ); + + // --------------------------------------------------- + // 5. Backward ReLU Configuration + // + // We'll define "diff_dst" as if the gradient from the next layer is all 1.0: + // shape = [2, 3], so all ones => [1,1,1,1,1,1]. 
+ let diff_dst_data = AlignedBuffer::new(&vec![1.0; src_data.len()]).unwrap(); + let diff_dst_md = src_mem.desc.clone_desc().unwrap(); + + // We'll store the result of the backward pass (the gradient w.r.t src) in diff_src + let diff_src_md = src_mem.desc.clone_desc().unwrap(); + + // We also need a "forward hint descriptor", from the forward pass + let forward_hint_desc = fwd_prim.desc.handle; // The C-level primitive_desc handle + + let bwd_config = BackwardEltwiseConfig { + alg_kind: Unary::RELU_USE_DST_FOR_BWD, + diff_src_desc: &diff_src_md, + diff_dest_desc: &diff_dst_md, + data_desc: &dst_mem.desc, // "data_desc" is typically the forward data or forward dst + alpha: 0.0, + beta: 0.0, + forward_hint_desc, + attr: &PrimitiveAttributes::new().unwrap(), + }; + + // 5b. Create the backward primitive + let bwd_prim = Primitive::new::<_, PropBackward, BackwardEltwise>( + bwd_config, + engine.clone(), + ) + .unwrap(); + + let diff_dst_mem = + Memory::new_with_user_buffer(engine.clone(), diff_dst_md, diff_dst_data).unwrap(); + + let a_buffer = + AlignedBuffer::zeroed(diff_src_md.get_size() / data_type_size(DataType::F32)).unwrap(); + + let diff_src_mem = Memory::new_with_user_buffer(engine.clone(), diff_src_md, a_buffer).unwrap(); + + // 5c. Execute backward ReLU + // + // We'll pass: + // - "diff_dst" as input (DNNL_ARG_DIFF_DST), + // - "dst" from forward pass if using *_USE_DST_FOR_BWD variant + // - "diff_src" as output. + bwd_prim + .execute( + &stream, + vec![ + // "diff_dst" as input gradient + ExecArg { + index: DNNL_ARG_DIFF_DST as i32, + mem: &diff_dst_mem, + }, + // "dst" from the forward pass if using *_USE_DST_FOR_BWD + ExecArg { + index: DNNL_ARG_DST as i32, + mem: &dst_mem, + }, + // "diff_src" as output gradient (w.r.t. src) + ExecArg { + index: DNNL_ARG_DIFF_SRC as i32, + mem: &diff_src_mem, + }, + ], + ) + .unwrap(); + + stream.wait().unwrap(); + + // --------------------------------------------------- + // 6. Validate Backward Gradient + // + // The rule for ReLU backward: + // dX = dY if Y > 0, else 0 ( for standard ReLU ). + // + // Our forward output was [0,2,0,4,0,5]. + // The "diff_dst" is all 1.0 => [1,1,1,1,1,1]. + // => diff_src = [0,1,0,1,0,1]. 
+ let backward_result = diff_src_mem.to_vec().unwrap(); + + let expected_backward = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]; + assert_eq!( + backward_result, expected_backward, + "Backward ReLU gradient mismatch" + ); +} diff --git a/tests/test_smoke.rs b/tests/test_smoke.rs index 45d88cb..c863a68 100644 --- a/tests/test_smoke.rs +++ b/tests/test_smoke.rs @@ -5,28 +5,24 @@ use { buffer::AlignedBuffer, data_type_size, descriptor::{ - new_plain_descriptor, DataType, DataTypeQuery, DimsQuery, MemoryDescriptor, + DataType, DataTypeQuery, DimsQuery, MemoryDescriptor, NDimsQuery, }, format_tag::{abc, abcd, x}, Memory, }, primitive::{ - attributes::PrimitiveAttributes, ExecArg, Primitive, PropBackward, PropBackwardData, - PropBackwardWeights, PropForwardInference, PropForwardTraining, + attributes::PrimitiveAttributes, ExecArg, Primitive, PropForwardInference, }, primitives::{ binary::{Binary, ForwardBinary, ForwardBinaryConfig}, - eltwise::{ - BackwardEltwise, BackwardEltwiseConfig, ForwardEltwise, ForwardEltwiseConfig, Unary, - }, matmul::{ForwardMatMul, ForwardMatMulConfig}, reduction::{ForwardReduction, ForwardReductionConfig, Reduction}, }, stream::Stream, }, onednnl_sys::{ - dnnl_data_type_t::dnnl_f32, DNNL_ARG_BIAS, DNNL_ARG_DIFF_DST, DNNL_ARG_DIFF_SRC, + dnnl_data_type_t::dnnl_f32, DNNL_ARG_BIAS, DNNL_ARG_DST, DNNL_ARG_SRC, DNNL_ARG_SRC_0, DNNL_ARG_SRC_1, DNNL_ARG_WEIGHTS, }, std::sync::Arc, @@ -286,444 +282,3 @@ pub fn test_reduction_smoke() { assert_eq!(dst_memory.to_vec(), Ok(vec![6.0])); } - -#[test] -fn test_relu_forward_backward() { - // 1. Create an engine (CPU in this example) - let engine = Engine::new(Engine::CPU, 0).unwrap(); - - // --------------------------------------------------- - // 2. Prepare input data (shape = [2, 3]) - // We'll intentionally include negative values to test ReLU clamping at 0. - let src_data: Vec = vec![-1.0f32, 2.0, -3.0, 4.0, 0.0, 5.0]; - let dims = [2, 3]; - - // 2a. Create a memory descriptor for src - let src_md = new_plain_descriptor(2, dims.to_vec(), DataType::F32); - - let dst_md = new_plain_descriptor(2, dims.to_vec(), DataType::F32); - - let forward_config = ForwardEltwiseConfig { - alg_kind: Unary::RELU, // ReLU forward - src_desc: &src_md, - dst_desc: &dst_md, - alpha: 0.0, - beta: 0.0, - attr: &PrimitiveAttributes::new().unwrap(), // no special attributes - }; - - // 3b. Create the forward primitive - let fwd_prim = - Primitive::new::<_, PropForwardTraining, ForwardEltwise<_>>(forward_config, engine.clone()) - .unwrap(); - - // 3c. Allocate memory for the forward result - - let a_buffer = - AlignedBuffer::zeroed(dst_md.get_size() / data_type_size(DataType::F32)).unwrap(); - - let dst_mem = Memory::new_with_user_buffer(engine.clone(), dst_md, a_buffer).unwrap(); - - let buffer = AlignedBuffer::new(&src_data).unwrap(); - - let src_mem = Memory::new_with_user_buffer(engine.clone(), src_md, buffer).unwrap(); - - let stream = Stream::new(engine.clone()).unwrap(); - - // 3d. Execute forward ReLU - fwd_prim - .execute( - &stream, - vec![ - ExecArg { - index: DNNL_ARG_SRC as i32, - mem: &src_mem, - }, - ExecArg { - index: DNNL_ARG_DST as i32, - mem: &dst_mem, - }, - ], - ) - .unwrap(); - - stream.wait().unwrap(); - - // --------------------------------------------------- - // 4. Validate Forward Output - // - // ReLU(x) = max(0, x). 
So for: - // -1.0 -> 0.0 - // 2.0 -> 2.0 - // -3.0 -> 0.0 - // 4.0 -> 4.0 - // 0.0 -> 0.0 - // 5.0 -> 5.0 - let forward_result = dst_mem.to_vec().unwrap(); - - let expected_forward: Vec = vec![0.0, 2.0, 0.0, 4.0, 0.0, 5.0]; - assert_eq!( - forward_result, expected_forward, - "Forward ReLU output mismatch" - ); - - // --------------------------------------------------- - // 5. Backward ReLU Configuration - // - // We'll define "diff_dst" as if the gradient from the next layer is all 1.0: - // shape = [2, 3], so all ones => [1,1,1,1,1,1]. - let diff_dst_data = AlignedBuffer::new(&vec![1.0; src_data.len()]).unwrap(); - let diff_dst_md = src_mem.desc.clone_desc().unwrap(); - - // We'll store the result of the backward pass (the gradient w.r.t src) in diff_src - let diff_src_md = src_mem.desc.clone_desc().unwrap(); - - // We also need a "forward hint descriptor", from the forward pass - let forward_hint_desc = fwd_prim.desc.handle; // The C-level primitive_desc handle - - let bwd_config = BackwardEltwiseConfig { - alg_kind: Unary::RELU_USE_DST_FOR_BWD, - diff_src_desc: &diff_src_md, - diff_dest_desc: &diff_dst_md, - data_desc: &dst_mem.desc, // "data_desc" is typically the forward data or forward dst - alpha: 0.0, - beta: 0.0, - forward_hint_desc, - attr: &PrimitiveAttributes::new().unwrap(), - }; - - // 5b. Create the backward primitive - let bwd_prim = Primitive::new::<_, PropBackward, BackwardEltwise>( - bwd_config, - engine.clone(), - ) - .unwrap(); - - let diff_dst_mem = - Memory::new_with_user_buffer(engine.clone(), diff_dst_md, diff_dst_data).unwrap(); - - let a_buffer = - AlignedBuffer::zeroed(diff_src_md.get_size() / data_type_size(DataType::F32)).unwrap(); - - let diff_src_mem = Memory::new_with_user_buffer(engine.clone(), diff_src_md, a_buffer).unwrap(); - - // 5c. Execute backward ReLU - // - // We'll pass: - // - "diff_dst" as input (DNNL_ARG_DIFF_DST), - // - "dst" from forward pass if using *_USE_DST_FOR_BWD variant - // - "diff_src" as output. - bwd_prim - .execute( - &stream, - vec![ - // "diff_dst" as input gradient - ExecArg { - index: DNNL_ARG_DIFF_DST as i32, - mem: &diff_dst_mem, - }, - // "dst" from the forward pass if using *_USE_DST_FOR_BWD - ExecArg { - index: DNNL_ARG_DST as i32, - mem: &dst_mem, - }, - // "diff_src" as output gradient (w.r.t. src) - ExecArg { - index: DNNL_ARG_DIFF_SRC as i32, - mem: &diff_src_mem, - }, - ], - ) - .unwrap(); - - stream.wait().unwrap(); - - // --------------------------------------------------- - // 6. Validate Backward Gradient - // - // The rule for ReLU backward: - // dX = dY if Y > 0, else 0 ( for standard ReLU ). - // - // Our forward output was [0,2,0,4,0,5]. - // The "diff_dst" is all 1.0 => [1,1,1,1,1,1]. - // => diff_src = [0,1,0,1,0,1]. - let backward_result = diff_src_mem.to_vec().unwrap(); - - let expected_backward = vec![0.0, 1.0, 0.0, 1.0, 0.0, 1.0]; - assert_eq!( - backward_result, expected_backward, - "Backward ReLU gradient mismatch" - ); -} - -#[test] -fn test_inner_product_nchw_to_nc_backprop() { - use onednnl::{ - onednnl_sys::{ - DNNL_ARG_BIAS, DNNL_ARG_DIFF_BIAS, DNNL_ARG_DIFF_DST, DNNL_ARG_DIFF_SRC, - DNNL_ARG_DIFF_WEIGHTS, DNNL_ARG_DST, DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, - }, - primitives::inner_product::{ - BackwardDataInnerProduct, BackwardDataInnerProductConfig, BackwardWeightsInnerProduct, - BackwardWeightsInnerProductConfig, ForwardInnerProduct, ForwardInnerProductConfig, - }, - }; - - // 1. 
Create an engine (CPU) - let engine = Engine::new(Engine::CPU, 0).unwrap(); - let stream = Stream::new(engine.clone()).unwrap(); - - // --------------------------------------------------- - // 2. Prepare input shapes/dimensions - // - // As in the C++ example: - // N = 3, IC = 3, IH = 227, IW = 227, OC = 96 - // We'll do an inner product from [N, IC, IH, IW] => [N, OC] - // Weights = [OC, IC, IH, IW] - // Bias = [OC] - // - let n: i64 = 15; - let ic: i64 = 3; - let ih: i64 = 227; - let iw: i64 = 227; - let oc: i64 = 96; - - let src_dims = [n, ic, ih, iw]; // shape [3, 3, 227, 227] - let weights_dims = [oc, ic, ih, iw]; // shape [96, 3, 227, 227] - let bias_dims = [oc]; // shape [96] - let dst_dims = [n, oc]; // shape [3, 96] - - // 2a. Create memory descriptors (plain / row-major) - let src_md = new_plain_descriptor(4, src_dims.to_vec(), DataType::F32); - let weights_md = new_plain_descriptor(4, weights_dims.to_vec(), DataType::F32); - let bias_md = new_plain_descriptor(1, bias_dims.to_vec(), DataType::F32); - let dst_md = new_plain_descriptor(2, dst_dims.to_vec(), DataType::F32); - - // 2b. Allocate some input data (all same, just for demonstration). - let src_len = (n * ic * ih * iw) as usize; - let weights_len = (oc * ic * ih * iw) as usize; - let bias_len = oc as usize; - let dst_len = (n * oc) as usize; - - let src_data = vec![0.5_f32; src_len]; - let weights_data = vec![0.1_f32; weights_len]; - let bias_data = vec![0.0_f32; bias_len]; - let dst_data = vec![0.0_f32; dst_len]; // Will hold forward output - - // Wrap them in user buffers - let src_buf = AlignedBuffer::new(&src_data).unwrap(); - let src_mem = - Memory::new_with_user_buffer(engine.clone(), src_md.clone_desc().unwrap(), src_buf) - .unwrap(); - - let weights_buf = AlignedBuffer::new(&weights_data).unwrap(); - let weights_mem = Memory::new_with_user_buffer( - engine.clone(), - weights_md.clone_desc().unwrap(), - weights_buf, - ) - .unwrap(); - - let bias_buf = AlignedBuffer::new(&bias_data).unwrap(); - let bias_mem = - Memory::new_with_user_buffer(engine.clone(), bias_md.clone_desc().unwrap(), bias_buf) - .unwrap(); - - let dst_buf = AlignedBuffer::new(&dst_data).unwrap(); - let dst_mem = - Memory::new_with_user_buffer(engine.clone(), dst_md.clone_desc().unwrap(), dst_buf) - .unwrap(); - - // --------------------------------------------------- - // 3. Forward Inner Product - let fwd_config = ForwardInnerProductConfig { - src_desc: &src_md, - weights_desc: &weights_md, - bias_desc: &bias_md, - dst_desc: &dst_md, - attr: &PrimitiveAttributes::new().unwrap(), - }; - - // 3a. Create the forward primitive - let fwd_prim = Primitive::new::<_, PropForwardTraining, ForwardInnerProduct<_>>( - fwd_config, - engine.clone(), - ) - .unwrap(); - - // 3b. Execute forward - fwd_prim - .execute( - &stream, - vec![ - ExecArg { - index: DNNL_ARG_SRC as i32, - mem: &src_mem, - }, - ExecArg { - index: DNNL_ARG_WEIGHTS as i32, - mem: &weights_mem, - }, - ExecArg { - index: DNNL_ARG_BIAS as i32, - mem: &bias_mem, - }, - ExecArg { - index: DNNL_ARG_DST as i32, - mem: &dst_mem, - }, - ], - ) - .unwrap(); - stream.wait().unwrap(); - - // 3c. Print a few forward outputs - let forward_result = dst_mem.to_vec().unwrap(); - println!("\n== Forward Pass =="); - println!("Forward output shape = [{}, {}]", n, oc); - println!( - "First few elements: {:?}", - &forward_result[..8.min(forward_result.len())] - ); - - // --------------------------------------------------- - // 4. Backward Weights: compute gradient w.r.t. 
weights and bias - // - // We'll define diff_dst as shape = [N, OC], typically the gradient - // from the next layer. For demonstration, fill with 1.0. - let diff_dst_data = vec![1.0_f32; dst_len]; - let diff_dst_buf = AlignedBuffer::new(&diff_dst_data).unwrap(); - let diff_dst_mem = - Memory::new_with_user_buffer(engine.clone(), dst_md.clone_desc().unwrap(), diff_dst_buf) - .unwrap(); - - // We'll store diff_weights in a new user buffer - let diff_weights_buf = AlignedBuffer::zeroed(weights_len).unwrap(); - let diff_weights_mem = Memory::new_with_user_buffer( - engine.clone(), - weights_md.clone_desc().unwrap(), - diff_weights_buf, - ) - .unwrap(); - - // We'll store diff_bias in a new user buffer - let diff_bias_buf = AlignedBuffer::zeroed(bias_len).unwrap(); - let diff_bias_mem = - Memory::new_with_user_buffer(engine.clone(), bias_md.clone_desc().unwrap(), diff_bias_buf) - .unwrap(); - - let bwd_weights_config = BackwardWeightsInnerProductConfig { - src_desc: &src_md, - diff_weights_desc: &weights_md, - diff_bias_desc: &bias_md, - diff_dst_desc: &dst_md, - hint_fwd_pd: &fwd_prim.desc, // from the forward primitive - attr: &PrimitiveAttributes::new().unwrap(), - }; - - // 4a. Create backward-weights primitive - let bwd_weights_prim = Primitive::new::<_, PropBackwardWeights, BackwardWeightsInnerProduct>( - bwd_weights_config, - engine.clone(), - ) - .unwrap(); - - // 4b. Execute backward-weights - bwd_weights_prim - .execute( - &stream, - vec![ - ExecArg { - index: DNNL_ARG_SRC as i32, - mem: &src_mem, - }, - ExecArg { - index: DNNL_ARG_DIFF_DST as i32, - mem: &diff_dst_mem, - }, - ExecArg { - index: DNNL_ARG_DIFF_WEIGHTS as i32, - mem: &diff_weights_mem, - }, - ExecArg { - index: DNNL_ARG_DIFF_BIAS as i32, - mem: &diff_bias_mem, - }, - ], - ) - .unwrap(); - stream.wait().unwrap(); - - // 4c. Print a few backward-weights outputs - let diff_weights_result = diff_weights_mem.to_vec().unwrap(); - let diff_bias_result = diff_bias_mem.to_vec().unwrap(); - println!("\n== Backward Weights =="); - println!( - "diff_weights: First few elements = {:?}", - &diff_weights_result[..8.min(diff_weights_result.len())] - ); - println!( - "diff_bias: First few elements = {:?}", - &diff_bias_result[..8.min(diff_bias_result.len())] - ); - - // --------------------------------------------------- - // 5. Backward Data: compute gradient w.r.t. src - // - // We'll produce diff_src from: - // - diff_dst + the original weights. - // The shape is the same as src_dims: [N, IC, IH, IW]. - let diff_src_buf = AlignedBuffer::zeroed(src_len).unwrap(); - let diff_src_mem = - Memory::new_with_user_buffer(engine.clone(), src_md.clone_desc().unwrap(), diff_src_buf) - .unwrap(); - - let bwd_data_config = BackwardDataInnerProductConfig { - diff_src_desc: &src_md, - weights_desc: &weights_md, - diff_dst_desc: &dst_md, - hint_fwd_pd: &fwd_prim.desc, // from forward pass - attr: &PrimitiveAttributes::new().unwrap(), - }; - - // 5a. Create backward-data primitive - let bwd_data_prim = Primitive::new::<_, PropBackwardData, BackwardDataInnerProduct>( - bwd_data_config, - engine.clone(), - ) - .unwrap(); - - // 5b. Execute backward-data - bwd_data_prim - .execute( - &stream, - vec![ - ExecArg { - index: DNNL_ARG_DIFF_DST as i32, - mem: &diff_dst_mem, - }, - ExecArg { - index: DNNL_ARG_WEIGHTS as i32, - mem: &weights_mem, - }, - ExecArg { - index: DNNL_ARG_DIFF_SRC as i32, - mem: &diff_src_mem, - }, - ], - ) - .unwrap(); - stream.wait().unwrap(); - - // 5c. 
Print a few backward-data outputs - let diff_src_result = diff_src_mem.to_vec().unwrap(); - println!("\n== Backward Data =="); - println!( - "diff_src shape = [N, IC, IH, IW] = [{}, {}, {}, {}]", - n, ic, ih, iw - ); - println!( - "diff_src: First few elements = {:?}", - &diff_src_result[..8.min(diff_src_result.len())] - ); -}