Skip to content

Commit

Permalink
fix: fix mod kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
kilinchange authored and bitzyz committed Jan 17, 2024
1 parent 3998833 commit 1ce6d81
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 10 deletions.
8 changes: 4 additions & 4 deletions src/04kernel/src/kernels/simple_binary/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ namespace refactor::kernel {
switch (dataType.internal) {
CASE_DT(std::fmod(a, b), F32);
CASE_DT(a % b, U8);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I8);
CASE_DT(static_cast<int8_t>(std::fmod(a, b)), I8);
CASE_DT(a % b, U16);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I16);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I32);
CASE_DT(a % b < 0 ? (a % b + b) : (a % b), I64);
CASE_DT(static_cast<int16_t>(std::fmod(a, b)), I16);
CASE_DT(static_cast<int32_t>(std::fmod(a, b)), I32);
CASE_DT(static_cast<int64_t>(std::fmod(a, b)), I64);
CASE_DT(std::fmod(a, b), F64);
CASE_DT(a % b, U32);
CASE_DT(a % b, U64);
Expand Down
12 changes: 8 additions & 4 deletions src/04kernel/src/kernels/simple_binary/cuda_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,18 @@ extern "C" __global__ void kernel(
case SimpleBinaryType::Fmod:
switch (dt) {
case DataType::U8:
case DataType::I8:
case DataType::U16:
case DataType::U32:
case DataType::U64:
return "a % b";
case DataType::I8:
return "static_cast<char>(fmodf(a, b))";
case DataType::I16:
return "static_cast<short>(fmodf(a, b))";
case DataType::I32:
return "static_cast<int>(fmodf(a, b))";
case DataType::I64:
case DataType::U32:
case DataType::U64:
return "a % b < 0 ? (a % b + b) : (a % b)";
return "static_cast<long long>(fmodf(a, b))";
case DataType::F32:
return "fmodf(a, b)";
case DataType::FP16:
Expand Down
9 changes: 8 additions & 1 deletion src/04kernel/src/kernels/simple_unary/cpu_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,14 @@ namespace refactor::kernel {
return static_cast<T>(std::tanh(static_cast<M>(x)));
}
template<class T> auto hardswishFun(T x) noexcept -> T {
return x * (std::max(0., std::min(1., 1.f / 6 * x + 0.5)));
auto res = x / 6.f;
if (res >= 0.5) {
return x;
} else if (res <= -0.5) {
return 0.;
} else {
return x * (x / 6.f + 0.5);
}
}
auto copyForUnsigned(size_t n) noexcept -> Routine {
return [n](runtime::Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ TEST(kernel, BinaryCpu) {
testBinaryCPU(SimpleBinaryType::Mul, [](float a, float b) { return a * b; });
testBinaryCPU(SimpleBinaryType::Div, [](float a, float b) { return a / b; });
testModCPU(SimpleBinaryType::Mod, [](int a, int b) { return a % b; });
testFmodWithI32CPU(SimpleBinaryType::Fmod, [](int a, int b) { return a % b < 0 ? (a % b + b) : (a % b); });
testFmodWithI32CPU(SimpleBinaryType::Fmod, [](int a, int b) { return static_cast<int32_t>(std::fmod(a, b)); });
testBinaryCPU(SimpleBinaryType::Fmod, [](float a, float b) { return std::fmod(a, b); });
}

Expand Down
1 change: 1 addition & 0 deletions src/07onnx/src/operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ namespace refactor::onnx {
REGISTER(And , SimpleBinary );
REGISTER(Or , SimpleBinary );
REGISTER(Xor , SimpleBinary );
REGISTER(Mod , SimpleBinary );
REGISTER(Abs , SimpleUnary );
REGISTER(Acos , SimpleUnary );
REGISTER(Acosh , SimpleUnary );
Expand Down

0 comments on commit 1ce6d81

Please sign in to comment.