Commit 2f2d922

add div, logical and pow

Chamberlain0w0 committed Jan 8, 2024
1 parent 80772cf commit 2f2d922
Showing 2 changed files with 91 additions and 29 deletions.
100 changes: 71 additions & 29 deletions src/04kernel/src/kernels/simple_binary/binary_cnnl.cc
@@ -16,7 +16,7 @@ namespace refactor::kernel {

auto K::build(Op op, Tensor const &a, Tensor const &b, Tensor const &c) noexcept -> KernelBox {
static const std::unordered_set<Op>
ARITHMETIC{Op::Add, Op::Sub, Op::Mul};
ARITHMETIC{Op::Add, Op::Sub, Op::Mul, Op::Div, Op::And, Op::Or, Op::Xor, Op::Pow};

#ifndef USE_BANG
return nullptr;
@@ -84,15 +84,20 @@ namespace refactor::kernel {
};
auto d = std::make_shared<Descriptors>(dataType != DT::F64);
cnnlOpTensorDesc_t cnnlOP;
cnnlLogicOp_t cnnlLogicOP;
if (opType == SimpleBinaryType::Add) {
cnnlOP = CNNL_OP_TENSOR_ADD;
} else if (opType == SimpleBinaryType::Sub) {
cnnlOP = CNNL_OP_TENSOR_ADD;
d->sub = true;
} else if (opType == SimpleBinaryType::Mul) {
cnnlOP = CNNL_OP_TENSOR_MUL;
} else {
UNREACHABLE();
} else if (opType == SimpleBinaryType::And) {
cnnlLogicOP = CNNL_LOGIC_OP_AND;
} else if (opType == SimpleBinaryType::Or) {
cnnlLogicOP = CNNL_LOGIC_OP_OR;
} else if (opType == SimpleBinaryType::Xor) {
cnnlLogicOP = CNNL_LOGIC_OP_XOR;
}
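// Note: Div and Pow intentionally fall through this chain without
// setting cnnlOP or cnnlLogicOP; they are dispatched straight to
// cnnlDiv_v2 / cnnlPow in the launcher below.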

setCnnlTensor(d->aDesc, dataType, slice(aDims.data(), aDims.size()));
@@ -103,40 +108,77 @@ namespace refactor::kernel {
cnnlDataTypeConvert(d->f32 ? DT::F32 : DT::F64),
CNNL_NOT_PROPAGATE_NAN));

return [swap = aDims != cDims, d](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
auto cnnlGetBinaryWorkspaceSize =
(opType == SimpleBinaryType::Add || opType == SimpleBinaryType::Sub || opType == SimpleBinaryType::Mul) ? cnnlGetOpTensorWorkspaceSize
: (opType == SimpleBinaryType::Div) ? cnnlGetDivWorkspaceSize
: (opType == SimpleBinaryType::And || opType == SimpleBinaryType::Or || opType == SimpleBinaryType::Xor) ? cnnlGetLogicOpWorkspaceSize
: (opType == SimpleBinaryType::Pow) ? cnnlGetPowWorkspaceSize
: nullptr;

if (cnnlGetBinaryWorkspaceSize == nullptr) {
UNREACHABLE();
}
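// The conditional chain above collapses to a single function pointer,
// which works only because every cnnlGet*WorkspaceSize getter involved
// shares one signature (handle, three tensor descriptors, size_t*).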

return [swap = aDims != cDims, d, cnnlGetBinaryWorkspaceSize, cnnlLogicOP, this](Resources &res, void *workspace, void const *const *inputs, void *const *outputs) {
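// Captured state: d keeps the descriptors alive via shared_ptr;
// capturing `this` means the kernel object must outlive this routine,
// and cnnlLogicOP is only meaningful for the And/Or/Xor branch.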
auto handle = res.fetchOrStore<CnnlContext>()->handle;
// name inputs and outputs
auto a = inputs[0],
b = inputs[1];
auto c = outputs[0];
auto alphaA = d->f32
? factor<fp32_t>(1)
: factor<fp64_t>(1),
alphaB = d->f32
? factor<fp32_t>(d->sub ? -1 : 1)
: factor<fp64_t>(d->sub ? -1 : 1),
beta = d->f32
? factor<fp32_t>(0)
: factor<fp64_t>(0);
size_t workspaceSize;
if (swap) {
CNNL_ASSERT(cnnlGetOpTensorWorkspaceSize(handle, d->bDesc,
d->aDesc, d->cDesc,
&workspaceSize));
CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
&alphaB, d->bDesc, b,
&alphaA, d->aDesc, a,
workspace, workspaceSize,
&beta, d->cDesc, c));
CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->bDesc,
d->aDesc, d->cDesc,
&workspaceSize));
} else {
CNNL_ASSERT(cnnlGetOpTensorWorkspaceSize(handle, d->aDesc,
d->bDesc, d->cDesc,
&workspaceSize));
CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
&alphaA, d->aDesc, a,
&alphaB, d->bDesc, b,
workspace, workspaceSize,
&beta, d->cDesc, c));
CNNL_ASSERT(cnnlGetBinaryWorkspaceSize(handle, d->aDesc,
d->bDesc, d->cDesc,
&workspaceSize));
}
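// In the broadcast case (aDims != cDims) the size query lists b's
// descriptor first, matching the swapped operand order that the
// Add/Sub/Mul branch passes to cnnlOpTensor below.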
if (this->opType == SimpleBinaryType::Add || this->opType == SimpleBinaryType::Sub || this->opType == SimpleBinaryType::Mul) {
auto alphaA = d->f32
? factor<fp32_t>(1)
: factor<fp64_t>(1),
alphaB = d->f32
? factor<fp32_t>(d->sub ? -1 : 1)
: factor<fp64_t>(d->sub ? -1 : 1),
beta = d->f32
? factor<fp32_t>(0)
: factor<fp64_t>(0);
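// Sub reuses CNNL_OP_TENSOR_ADD: with alphaB = -1 the call below
// computes c = 1*a + (-1)*b + 0*c.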

if (swap) {
CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
&alphaB, d->bDesc, b,
&alphaA, d->aDesc, a,
workspace, workspaceSize,
&beta, d->cDesc, c));
} else {
CNNL_ASSERT(cnnlOpTensor(handle, d->opDesc,
&alphaA, d->aDesc, a,
&alphaB, d->bDesc, b,
workspace, workspaceSize,
&beta, d->cDesc, c));
}
} else if (this->opType == SimpleBinaryType::Div) {
CNNL_ASSERT(cnnlDiv_v2(handle,
CNNL_COMPUTATION_HIGH_PRECISION,
d->aDesc, a,
d->bDesc, b,
workspace, workspaceSize,
d->cDesc, c));
} else if (opType == SimpleBinaryType::And || opType == SimpleBinaryType::Or || opType == SimpleBinaryType::Xor) {
CNNL_ASSERT(cnnlLogicOp(handle, cnnlLogicOP,
d->aDesc, a,
d->bDesc, b,
workspace, workspaceSize,
d->cDesc, c));
} else if (opType == SimpleBinaryType::Pow) {
CNNL_ASSERT(cnnlPow(handle,
CNNL_COMPUTATION_HIGH_PRECISION,
d->aDesc, a,
d->bDesc, b,
workspace, workspaceSize,
d->cDesc, c));
}
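// No trailing UNREACHABLE() here: build() already rejected any op
// outside the supported set, so every remaining opType is covered.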
};
}
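The launcher fixes its workspace-size query once, at build time, by storing one of several same-signature CNNL getters in a function pointer that the returned lambda captures. Below is a minimal, CNNL-free sketch of that dispatch pattern; every name in it (Op, build, getOpTensorWorkspace, getDivWorkspace) is an illustrative stand-in, not a repo or CNNL symbol.

#include <cstdio>
#include <functional>

// Stand-ins for the CNNL workspace-size getters; all share one
// signature, so a single function pointer can cover every op family.
int getOpTensorWorkspace(int n) { return n * 4; }
int getDivWorkspace(int n) { return n * 8; }

enum class Op { Add, Div };

std::function<void(int)> build(Op op) {
    // Pick the getter once, at build time, like the ternary chain
    // over cnnlGet*WorkspaceSize above.
    auto getWorkspace = (op == Op::Add) ? getOpTensorWorkspace
                      : (op == Op::Div) ? getDivWorkspace
                                        : nullptr;
    if (getWorkspace == nullptr) { return nullptr; }
    // Capture the chosen getter by value so the per-launch path needs
    // no further branching to size its workspace.
    return [getWorkspace](int elements) {
        std::printf("workspace bytes: %d\n", getWorkspace(elements));
    };
}

int main() {
    build(Op::Add)(10);// prints "workspace bytes: 40"
    build(Op::Div)(10);// prints "workspace bytes: 80"
    return 0;
}

Resolving the branch once at build time keeps the launch path free of per-call dispatch for the sizing step, which is the same trade the commit makes.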
20 changes: 20 additions & 0 deletions src/04kernel/test/kernels/simple_binary/test_binary_cnnl.cpp
@@ -62,6 +62,26 @@ TEST(kernel, BinaryCnnlSub) {
testBinaryCnnl(SimpleBinaryType::Sub, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
}

TEST(kernel, BinaryCnnlDiv) {
testBinaryCnnl(SimpleBinaryType::Div, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
}

// TEST(kernel, BinaryCnnlAnd) {
// testBinaryCnnl(SimpleBinaryType::And, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
// }

// TEST(kernel, BinaryCnnlOr) {
// testBinaryCnnl(SimpleBinaryType::Or, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
// }

// TEST(kernel, BinaryCnnlXor) {
// testBinaryCnnl(SimpleBinaryType::Xor, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
// }

TEST(kernel, BinaryCnnlPow) {
testBinaryCnnl(SimpleBinaryType::Pow, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40}, Shape{10, 20, 30, 40});
}

TEST(kernel, BinaryCnnlBroadcast) {
testBinaryCnnl(SimpleBinaryType::Add, Shape{3, 4, 5, 6}, Shape{}, Shape{3, 4, 5, 6});
}
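In BinaryCnnlBroadcast, Shape{} denotes a rank-0 (scalar) operand, so b is combined with every element of a. A host-side reference for that case might look like the sketch below; addScalarBroadcast is a hypothetical helper for illustration, not part of the test suite.

#include <cassert>
#include <cmath>
#include <vector>

// Reference semantics for broadcasting a scalar b over tensor a,
// shown for Add; the same indexing applies to the other binary ops.
std::vector<float> addScalarBroadcast(const std::vector<float> &a, float b) {
    std::vector<float> c(a.size());
    for (size_t i = 0; i < a.size(); ++i) { c[i] = a[i] + b; }
    return c;
}

int main() {
    std::vector<float> a(3 * 4 * 5 * 6, 2.0f);// Shape{3, 4, 5, 6}
    auto c = addScalarBroadcast(a, 1.5f);// b has Shape{}
    for (float v : c) { assert(std::fabs(v - 3.5f) < 1e-6f); }
    return 0;
}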
