Skip to content

Commit

Permalink
[fix] fix conv when relu/relu6 appear at the same time
Browse files Browse the repository at this point in the history
  • Loading branch information
Alcanderian committed Apr 6, 2024
1 parent 3a4006a commit ab0bb94
Show file tree
Hide file tree
Showing 10 changed files with 20 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,8 @@ ppl::common::RetCode conv2d_n16cx_direct_fp32_avx512_executor::execute()
if (is_last_ic) {
if (with_relu) {
kernel_flags |= KERNEL_FLAG_RELU();
} else if (with_relu6) {
}
if (with_relu6) {
kernel_flags |= KERNEL_FLAG_RELU6();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ ppl::common::RetCode conv2d_n16cx_gemm_direct_fp32_avx512_executor::execute()
if (is_last_ic) {
if (with_relu) {
kernel_flags |= conv2d_n16cx_gemm_direct_kernel_fp32_avx512::flag::RELU;
} else if (with_relu6) {
}
if (with_relu6) {
kernel_flags |= conv2d_n16cx_gemm_direct_kernel_fp32_avx512::flag::RELU6;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ ppl::common::RetCode conv2d_im2col_gemm_fp32_fma_executor::execute()
if (is_last_k) {
if (with_relu) {
kernel_flags |= KERNEL_FLAG_RELU();
} else if (with_relu6) {
}
if (with_relu6) {
kernel_flags |= KERNEL_FLAG_RELU6();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ ppl::common::RetCode conv2d_n16cx_direct_fp32_fma_executor::execute()
if (is_last_ic) {
if (with_relu) {
kernel_flags |= ker_flag::RELU;
} else if (with_relu6) {
}
if (with_relu6) {
kernel_flags |= ker_flag::RELU6;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,8 @@ ppl::common::RetCode conv2d_n16cx_gemm_direct_fp32_fma_executor::execute()
if (is_last_ic) {
if (with_relu) {
kernel_flags |= conv2d_n16cx_gemm_direct_kernel_fp32_fma::flag::RELU;
} else if (with_relu6) {
}
if (with_relu6) {
kernel_flags |= conv2d_n16cx_gemm_direct_kernel_fp32_fma::flag::RELU6;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -657,12 +657,12 @@ ppl::common::RetCode conv2d_depthwise_fp32_sse_executor::execute()

auto dst_trans_func = conv2d_depthwise_fp32_sse_dst_trans<0, 0>;
if (with_sum) {
dst_trans_func = conv2d_depthwise_fp32_sse_dst_trans<0, 1>;
if (with_relu) dst_trans_func = conv2d_depthwise_fp32_sse_dst_trans<1, 1>;
else if (with_relu6) dst_trans_func = conv2d_depthwise_fp32_sse_dst_trans<6, 1>;
else dst_trans_func = conv2d_depthwise_fp32_sse_dst_trans<0, 1>;
if (with_relu6) dst_trans_func = conv2d_depthwise_fp32_sse_dst_trans<6, 1>;
} else {
if (with_relu) dst_trans_func = conv2d_depthwise_fp32_sse_dst_trans<1, 0>;
else if (with_relu6) dst_trans_func = conv2d_depthwise_fp32_sse_dst_trans<6, 0>;
if (with_relu6) dst_trans_func = conv2d_depthwise_fp32_sse_dst_trans<6, 0>;
}

PRAGMA_OMP_PARALLEL_FOR()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,8 @@ ppl::common::RetCode conv2d_im2col_gemm_fp32_sse_executor::execute()
if (is_last_k) {
if (with_relu) {
kp.pick<int64_t>(conv_gemm_kernel_fp32_sse::param_def::FLAGS_IDX) |= conv_gemm_kernel_fp32_sse::flag::RELU;
} else if (with_relu6) {
}
if (with_relu6) {
kp.pick<int64_t>(conv_gemm_kernel_fp32_sse::param_def::FLAGS_IDX) |= conv_gemm_kernel_fp32_sse::flag::RELU6;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,8 @@ ppl::common::RetCode conv2d_n8cx_direct_fp32_sse_executor::execute()
if (is_last_ic) {
if (with_relu) {
kernel_flags |= KERNEL_FLAG_RELU();
} else if (with_relu6) {
}
if (with_relu6) {
kernel_flags |= KERNEL_FLAG_RELU6();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ ppl::common::RetCode conv2d_n8cx_gemm_direct_fp32_sse_executor::execute()
if (is_last_ic) {
if (with_relu) {
kernel_flags |= KERNEL_FLAG_RELU();
} else if (with_relu6) {
}
if (with_relu6) {
kernel_flags |= KERNEL_FLAG_RELU6();
}
}
Expand Down
1 change: 1 addition & 0 deletions test/test_conv2d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ for (int64_t lcfg = 0; lcfg < Flag_loop_cfg; ++lcfg) {
if (Flag_relu == 1) {
param.fuse_flag |= ppl::kernel::x86::conv_fuse_flag::RELU;
} else if (Flag_relu == 6) {
param.fuse_flag |= ppl::kernel::x86::conv_fuse_flag::RELU;
param.fuse_flag |= ppl::kernel::x86::conv_fuse_flag::RELU6;
}

Expand Down

0 comments on commit ab0bb94

Please sign in to comment.