diff --git a/src/04kernel/cuda/include/kernel/cuda/pad.cuh b/src/04kernel/cuda/include/kernel/cuda/pad.cuh
new file mode 100644
index 00000000..bc2dfb0a
--- /dev/null
+++ b/src/04kernel/cuda/include/kernel/cuda/pad.cuh
@@ -0,0 +1,22 @@
+#ifndef KERNEL_CUDA_PAD_CUH
+#define KERNEL_CUDA_PAD_CUH
+
+#include "threads_distributer.cuh"
+#include <cstdint>
+
+namespace refactor::kernel::cuda {
+
+    struct PadDimInfo {
+        unsigned int strideI, strideO, padS, dimI;
+    };
+
+    void launchPad(
+        KernelLaunchParameters const &,
+        uint8_t const *src, uint8_t const *src_const,
+        PadDimInfo const *dims, void *output,
+        unsigned int rank,
+        unsigned int blockSize);
+
+}// namespace refactor::kernel::cuda
+
+#endif// KERNEL_CUDA_PAD_CUH
diff --git a/src/04kernel/cuda/include/kernel/cuda/slice.cuh b/src/04kernel/cuda/include/kernel/cuda/slice.cuh
index 770477fd..f381a525 100644
--- a/src/04kernel/cuda/include/kernel/cuda/slice.cuh
+++ b/src/04kernel/cuda/include/kernel/cuda/slice.cuh
@@ -5,14 +5,14 @@
 namespace refactor::kernel::cuda {
 
-    struct DimInfo {
+    struct SliceDimInfo {
         unsigned int strideO, skip;
         int strideI;
     };
 
     void launchSlice(
         KernelLaunchParameters const &,
-        void const *src, DimInfo const *dims, void *output,
+        void const *src, SliceDimInfo const *dims, void *output,
         unsigned int rank,
         unsigned int blockSize);
 
diff --git a/src/04kernel/cuda/src/pad.cu b/src/04kernel/cuda/src/pad.cu
new file mode 100644
index 00000000..c4d8e420
--- /dev/null
+++ b/src/04kernel/cuda/src/pad.cu
@@ -0,0 +1,63 @@
+#include "kernel/cuda/pad.cuh"
+#include "macro.cuh"
+#include <cstdint>
+
+namespace refactor::kernel::cuda {
+
+    __global__ static void padKernel(
+        unsigned long long n,
+        uint8_t const *__restrict__ src,
+        uint8_t const *__restrict__ src_const,
+        PadDimInfo const *__restrict__ dims,
+        uint8_t *__restrict__ dst,
+        unsigned int rank,
+        unsigned int blockSize) {
+        for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
+                  step = blockDim.x * gridDim.x;
+             tid < n;
+             tid += step) {
+            long rem = tid, j = 0;
+            bool flag = false;
+            for (auto i = 0; i < rank; ++i) {
+                auto strideO = __ldg(&(dims[i].strideO));
+                auto strideI = __ldg(&(dims[i].strideI));
+                auto padS = __ldg(&(dims[i].padS));
+                auto dimI = __ldg(&(dims[i].dimI));
+                auto pos = rem / strideO - padS;
+                if (pos < 0 || pos >= dimI) {
+                    flag = true;
+                    break;
+                }
+                j += pos * strideI;
+                rem %= strideO;
+            }
+            if (flag) {
+                optimizedMemcpy(dst + tid * blockSize, src_const, blockSize);
+            } else {
+                optimizedMemcpy(dst + tid * blockSize, src + j * blockSize, blockSize);
+            }
+        }
+    }
+
+    void launchPad(
+        KernelLaunchParameters const &params,
+        uint8_t const *src, uint8_t const *src_const,
+        PadDimInfo const *dims, void *output,
+        unsigned int rank,
+        unsigned int blockSize) {
+
+        padKernel<<<
+            params.gridSize,
+            params.blockSize,
+            0,
+            reinterpret_cast<cudaStream_t>(params.stream)>>>(
+            params.n,
+            src,
+            src_const,
+            dims,
+            reinterpret_cast<uint8_t *>(output),
+            rank,
+            blockSize);
+    }
+
+}// namespace refactor::kernel::cuda
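Note for review: padKernel's per-thread body is equivalent to the host-side mapping below. This is a standalone sketch; the function name padSource and the local PadDimInfo mirror are illustrative, not part of the change.

#include <cstdint>
#include <vector>

// Mirror of cuda::PadDimInfo; strides are counted in blocks of blockSize bytes.
struct PadDimInfo {
    unsigned int strideI, strideO, padS, dimI;
};

// Map an output block index to its input block index, or return -1 when the
// block lies in the padded border (then the constant block is written instead).
inline long padSource(long tid, std::vector<PadDimInfo> const &dims) {
    long rem = tid, j = 0;
    for (auto const &d : dims) {
        auto pos = rem / d.strideO - d.padS;// coordinate along this axis
        if (pos < 0 || pos >= d.dimI) { return -1; }
        j += pos * d.strideI;
        rem %= d.strideO;
    }
    return j;
}

launchPad runs exactly this mapping once per output block, spread over a grid-stride loop.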
diff --git a/src/04kernel/cuda/src/slice.cu b/src/04kernel/cuda/src/slice.cu
index 802b7cfe..ce092e9c 100644
--- a/src/04kernel/cuda/src/slice.cu
+++ b/src/04kernel/cuda/src/slice.cu
@@ -7,7 +7,7 @@ namespace refactor::kernel::cuda {
     __global__ static void sliceKernel(
         unsigned long long n,
         uint8_t const *__restrict__ src,
-        DimInfo const *__restrict__ dims,
+        SliceDimInfo const *__restrict__ dims,
         uint8_t *__restrict__ dst,
         unsigned int rank,
         unsigned int blockSize) {
@@ -29,7 +29,7 @@ namespace refactor::kernel::cuda {
 
     void launchSlice(
         KernelLaunchParameters const &params,
-        void const *src, DimInfo const *dims, void *output,
+        void const *src, SliceDimInfo const *dims, void *output,
         unsigned int rank,
         unsigned int blockSize) {
         sliceKernel<<<
diff --git a/src/04kernel/include/kernel/attributes/pad_info.h b/src/04kernel/include/kernel/attributes/pad_info.h
new file mode 100644
index 00000000..ff39f097
--- /dev/null
+++ b/src/04kernel/include/kernel/attributes/pad_info.h
@@ -0,0 +1,62 @@
+#ifndef KERNEL_PAD_ATTRIBUTES_H
+#define KERNEL_PAD_ATTRIBUTES_H
+
+#include "../tensor.h"
+#include "common.h"
+
+namespace refactor::kernel {
+
+    struct PadType {
+        enum : uint8_t {
+            Constant,
+            Reflect,
+            Edge,
+            Wrap,
+        } type;
+
+        constexpr PadType() noexcept
+            : type(Constant) {}
+        constexpr PadType(decltype(type) type_) noexcept
+            : type(type_) {}
+        constexpr operator decltype(type)() const noexcept {
+            return type;
+        }
+        constexpr std::string_view toString() const noexcept {
+            switch (type) {
+                case Constant:
+                    return "Constant";
+                case Reflect:
+                    return "Reflect";
+                case Edge:
+                    return "Edge";
+                case Wrap:
+                    return "Wrap";
+                default:
+                    UNREACHABLE();
+            }
+        }
+    };
+
+    namespace pad {
+        struct Dim {
+            int64_t dimI, dimO, pads;
+        };
+    }// namespace pad
+
+    using PadDimension = std::vector<pad::Dim>;
+
+    struct PadInfo {
+        struct Dim {
+            dim_t strideI, strideO, padS, dimI;
+        };
+        std::vector<Dim> dims;
+        dim_t blockCount, blockSize;
+
+        PadInfo(decltype(dims), dim_t, dim_t) noexcept;
+        PadInfo(PadDimension, Tensor const &);
+        void reform(dim_t) noexcept;
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_PAD_ATTRIBUTES_H
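For orientation, a worked example of the PadInfo fields (illustrative values; it mirrors the first CPU test further below): padding a 2x3 F32 tensor by 1 on every side gives a 4x5 output, and the constructor produces

using refactor::kernel::PadInfo;

PadInfo::Dim expected[]{
    {3, 5, 1, 2},// outer axis: strideI = 3, strideO = 5, padS = 1, dimI = 2
    {1, 1, 1, 3},// inner axis: strideI = 1, strideO = 1, padS = 1, dimI = 3
};
// blockCount = 4 * 5 = 20 output blocks, blockSize = sizeof(float) = 4 bytes.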
diff --git a/src/04kernel/include/kernel/collectors/pad.h b/src/04kernel/include/kernel/collectors/pad.h
new file mode 100644
index 00000000..fd7d9744
--- /dev/null
+++ b/src/04kernel/include/kernel/collectors/pad.h
@@ -0,0 +1,21 @@
+#ifndef KERNEL_PAD_H
+#define KERNEL_PAD_H
+
+#include "../attributes/pad_info.h"
+#include "../collector.h"
+
+namespace refactor::kernel {
+
+    struct PadCollector final : public InfoCollector {
+        PadDimension dims;
+        PadType mode;
+
+        explicit PadCollector(decltype(_target) target, PadDimension dims_, PadType mode_) noexcept
+            : InfoCollector(target), dims(std::move(dims_)), mode(mode_) {}
+
+        std::vector<KernelBox>
+        filter(TensorRefs inputs, TensorRefs outputs) const final;
+    };
+}// namespace refactor::kernel
+
+#endif// KERNEL_PAD_H
diff --git a/src/04kernel/src/attributes/pad_info.cc b/src/04kernel/src/attributes/pad_info.cc
new file mode 100644
index 00000000..bc830297
--- /dev/null
+++ b/src/04kernel/src/attributes/pad_info.cc
@@ -0,0 +1,75 @@
+#include "kernel/attributes/pad_info.h"
+#include <numeric>
+
+namespace refactor::kernel {
+    using PI = PadInfo;
+
+    PI::PadInfo(decltype(dims) dims_, dim_t blockCount_, dim_t blockSize_) noexcept
+        : dims(std::move(dims_)), blockCount(blockCount_), blockSize(blockSize_) {}
+
+    PI::PadInfo(PadDimension dims_, Tensor const &input) : dims{}, blockCount(1),
+                                                           blockSize(input.dataType.size()) {
+        size_t rank = input.rank();
+        ASSERT(dims_.size() == rank, "Invalid PadDimension for PadInfo.");
+
+        // squeeze out size-1 dimensions that are not padded
+        size_t j = 0;
+        for (auto i : range0_(rank)) {
+            if (dims_[i].dimI != dims_[i].dimO || dims_[i].dimI != 1) {
+                if (j < i) { dims_[j] = dims_[i]; }
+                j++;
+            }
+        }
+        dims_.resize(rank = j);
+
+        // merge trailing contiguous dimensions
+        for (auto i : range0_(rank).rev()) {
+            if (auto d = dims_[i].dimI; d == dims_[i].dimO) {
+                blockSize *= d;
+                dims_.pop_back();
+            } else {
+                auto &dim = dims_[i];
+                if (auto times = std::gcd(std::gcd(dim.dimI, dim.pads), dim.dimO); times > 1) {
+                    blockSize *= times;
+                    dim.dimI /= times;
+                    dim.dimO /= times;
+                    dim.pads /= times;
+                }
+                break;
+            }
+        }
+        dims.reserve(rank = dims_.size());
+
+        dim_t strideI = 1, strideO = 1;
+        for (auto i : range0_(rank).rev()) {
+            auto const &dim = dims_[i];
+            dims.push_back({
+                strideI,
+                strideO,
+                static_cast<dim_t>(dim.pads),
+                static_cast<dim_t>(dim.dimI),
+            });
+            strideI *= dim.dimI;
+            strideO *= dim.dimO;
+        }
+        std::reverse(dims.begin(), dims.end());
+        blockCount = strideO;
+    }
+
+    void PI::reform(dim_t maxblockSize) noexcept {
+        auto blockSize_ = std::gcd(blockSize, maxblockSize);
+        if (blockSize_ == blockSize) { return; }
+        auto t = blockSize / blockSize_;
+        blockCount *= t;
+        blockSize = blockSize_;
+        for (auto &d : dims) {
+            d.strideI *= t;
+            d.strideO *= t;
+            d.padS *= t;
+            d.dimI *= t;
+        }
+        dims.resize(dims.size() + 1);
+        dims.back() = {1, 1, 0, t};
+    }
+
+}// namespace refactor::kernel
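A hypothetical walk-through of the squeeze/merge above (it matches the last CPU test below):

// F32 input 2x3x1x4 padded to 4x5x1x8 with per-axis pads {1, 1, 0, 2}:
//   squeeze : the {1 -> 1, pads 0} axis is dropped        -> {2->4, 3->5, 4->8}
//   merge   : gcd(gcd(4, 2), 8) = 2 on the trailing axis  -> {2->4, 3->5, 2->4},
//             blockSize = 4 B * 2 = 8 B
//   strides : blockCount = 4 * 5 * 4 = 80 blocks of 8 bytes (= 640 B output)
// reform(16) is then a no-op here (gcd(8, 16) == 8); for, say, blockSize = 24 B
// it would keep 8 B blocks and split t = 3 off into an extra innermost
// {strideI 1, strideO 1, padS 0, dimI 3} dimension.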
diff --git a/src/04kernel/src/attributes/slice_info.cc b/src/04kernel/src/attributes/slice_info.cc
index a3397c82..fa3039ee 100644
--- a/src/04kernel/src/attributes/slice_info.cc
+++ b/src/04kernel/src/attributes/slice_info.cc
@@ -46,7 +46,6 @@ namespace refactor::kernel {
                 shape.pop_back();
                 dims_.pop_back();
             } else {
-                dims.resize(rank = shape.size());
                 if (auto &dim = dims_[i]; dim.step == 1) {
                     if (auto times = std::gcd(std::gcd(dim.start, dim.length), shape[i]); times > 1) {
                         blockSize *= times;
@@ -58,6 +57,7 @@ namespace refactor::kernel {
                 break;
             }
         }
+        dims.resize(rank = shape.size());
         dim_t strideI = 1;
         for (auto i : range0_(rank).rev()) {
             auto const &dim = dims_[i];
diff --git a/src/04kernel/src/collectors/pad.cc b/src/04kernel/src/collectors/pad.cc
new file mode 100644
index 00000000..f4c995e0
--- /dev/null
+++ b/src/04kernel/src/collectors/pad.cc
@@ -0,0 +1,32 @@
+#include "kernel/collectors/pad.h"
+#include "../kernels/pad/cpu_kernel.hh"
+#include "../kernels/pad/cuda_kernel.hh"
+
+namespace refactor::kernel {
+
+    std::vector<KernelBox>
+    PadCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
+        auto const &input = inputs[0];
+        PadInfo info(dims, input);
+        auto const_value = inputs.size() >= 3 ? std::make_optional(inputs[2]) : std::nullopt;
+
+        std::vector<KernelBox> ans;
+        switch (_target) {
+            case decltype(_target)::Cpu:
+                if (auto ptr = PadCpu::build(std::move(info), mode, const_value); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            case decltype(_target)::Nvidia:
+                if (auto ptr = PadCuda::build(std::move(info), mode, const_value); ptr) {
+                    ans.emplace_back(std::move(ptr));
+                }
+                break;
+            default:
+                UNREACHABLEX(void, "Unknown target");
+        }
+        return ans;
+    }
+
+}// namespace refactor::kernel
+
diff --git a/src/04kernel/src/kernels/pad/cpu_kernel.cc b/src/04kernel/src/kernels/pad/cpu_kernel.cc
new file mode 100644
index 00000000..f249ec83
--- /dev/null
+++ b/src/04kernel/src/kernels/pad/cpu_kernel.cc
@@ -0,0 +1,66 @@
+#include "cpu_kernel.hh"
+#include <execution>
+
+namespace refactor::kernel {
+    using K = PadCpu;
+
+    K::PadCpu(PadInfo info_, PadType mode_, size_t value_) noexcept
+        : Kernel(), info(std::move(info_)), mode(mode_), valueLength(value_) {}
+
+    auto K::build(PadInfo info, PadType mode, std::optional<std::reference_wrapper<Tensor const>> value_) noexcept -> KernelBox {
+        if (mode != PadType::Constant) {
+            return nullptr;
+        }
+        size_t value = value_ ? value_->get().dataType.size() : 0;
+        return std::make_unique<K>(std::move(info), mode, value);
+    }
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t {
+        return typeId();
+    }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing pad operation on generic cpu";
+    }
+
+    auto K::lower(Resources &) const noexcept -> RoutineWorkspace {
+        using namespace runtime;
+
+        return [info = this->info, value = this->valueLength](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
+            auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
+            auto dst = reinterpret_cast<uint8_t *>(outputs[0]);
+            // tile the constant value so one memcpy fills a whole block
+            std::vector<uint8_t> defaultValue(info.blockSize, 0);
+            if (value != 0) {
+                auto constValue = reinterpret_cast<uint8_t const *>(inputs[2]);
+                for (auto i : range0_(info.blockSize / value)) {
+                    std::memcpy(defaultValue.data() + i * value, constValue, value);
+                }
+            }
+            std::for_each_n(std::execution::par_unseq,
+                            natural_t(0), info.blockCount,
+                            [=, &info](auto i) {
+                                long rem = i, j = 0;
+                                bool flag = false;
+                                for (auto const &dim : info.dims) {
+                                    auto pos = rem / dim.strideO - dim.padS;
+                                    if (pos < 0 || pos >= dim.dimI) {
+                                        flag = true;
+                                        break;
+                                    }
+                                    j += pos * dim.strideI;
+                                    rem %= dim.strideO;
+                                }
+                                if (flag) {
+                                    std::memcpy(dst + i * info.blockSize, defaultValue.data(), info.blockSize);
+                                } else {
+                                    std::memcpy(dst + i * info.blockSize, src + j * info.blockSize, info.blockSize);
+                                }
+                            });
+        };
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/pad/cpu_kernel.hh b/src/04kernel/src/kernels/pad/cpu_kernel.hh
new file mode 100644
index 00000000..d314c520
--- /dev/null
+++ b/src/04kernel/src/kernels/pad/cpu_kernel.hh
@@ -0,0 +1,27 @@
+#ifndef KERNEL_PAD_CPU_KERNEL_HH
+#define KERNEL_PAD_CPU_KERNEL_HH
+
+#include "kernel/attributes/pad_info.h"
+#include "kernel/kernel.h"
+
+namespace refactor::kernel {
+
+    struct PadCpu final : public Kernel {
+        PadInfo info;
+        PadType mode;
+        size_t valueLength;
+
+        explicit PadCpu(PadInfo, PadType, size_t) noexcept;
+
+        static KernelBox build(PadInfo, PadType, std::optional<std::reference_wrapper<Tensor const>>) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+        RoutineWorkspace lower(Resources &) const noexcept final;
+    };
+
+}// namespace refactor::kernel
+
+#endif// KERNEL_PAD_CPU_KERNEL_HH
+
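A minimal usage sketch for the CPU kernel above, using only names from this diff (the 1-D shape and values are illustrative):

PadDimension dims{{3, 5, 1}};// {dimI, dimO, pads}: pad 3 -> 5, one element per side
auto x = Tensor::share(DataType::F32, Shape{3});
auto kernel = PadCpu::build(PadInfo(dims, *x), PadType::Constant, std::nullopt);
auto res = runtime::Resources();
auto routine = kernel->lower(res).routine;
std::vector<float> in{1, 2, 3}, out(5);
void const *inputs[]{in.data()};
void *outputs[]{out.data()};
routine(res, nullptr, inputs, outputs);// out == {0, 1, 2, 3, 0}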
diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.cc b/src/04kernel/src/kernels/pad/cuda_kernel.cc
new file mode 100644
index 00000000..5aa302d9
--- /dev/null
+++ b/src/04kernel/src/kernels/pad/cuda_kernel.cc
@@ -0,0 +1,31 @@
+#include "cuda_kernel.hh"
+
+namespace refactor::kernel {
+    using K = PadCuda;
+
+    K::PadCuda(PadInfo info_, PadType mode_, size_t value_) noexcept
+        : Kernel(), info(std::move(info_)), mode(mode_), valueLength(value_) {}
+
+    auto K::build(PadInfo info, PadType mode, std::optional<std::reference_wrapper<Tensor const>> value_) noexcept -> KernelBox {
+#ifndef USE_CUDA
+        return nullptr;
+#endif
+        if (mode != PadType::Constant) {
+            return nullptr;
+        }
+        size_t value = value_ ? value_->get().dataType.size() : 0;
+        info.reform(16);
+        return std::make_unique<K>(std::move(info), mode, value);
+    }
+
+    auto K::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto K::kernelTypeId() const noexcept -> size_t { return typeId(); }
+    auto K::description() const noexcept -> std::string_view {
+        return "Performing Pad using CUDA";
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.cu b/src/04kernel/src/kernels/pad/cuda_kernel.cu
new file mode 100644
index 00000000..d9c909b6
--- /dev/null
+++ b/src/04kernel/src/kernels/pad/cuda_kernel.cu
@@ -0,0 +1,39 @@
+#include "cuda_kernel.hh"
+#include "kernel/cuda/pad.cuh"
+#include <thrust/device_vector.h>
+#include <thrust/host_vector.h>
+
+namespace refactor::kernel {
+    using namespace runtime;
+
+    auto PadCuda::lower(Resources &) const noexcept -> RoutineWorkspace {
+        thrust::host_vector<cuda::PadDimInfo> dims(info.dims.size());
+        std::transform(info.dims.begin(), info.dims.end(),
+                       dims.begin(),
+                       [](auto const &d) {
+                           return cuda::PadDimInfo{
+                               d.strideI,
+                               d.strideO,
+                               d.padS,
+                               d.dimI,
+                           };
+                       });
+        return [dims = thrust::device_vector<cuda::PadDimInfo>(dims),
+                params = cuda::ThreadsDistributer()(info.blockCount),
+                blockSize = info.blockSize,
+                value = this->valueLength](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
+            auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
+            thrust::device_vector<uint8_t> defaultValue(blockSize, 0);
+            if (value != 0) {
+                auto constValue = reinterpret_cast<uint8_t const *>(inputs[2]);
+                for (auto i : range0_(blockSize / value)) {
+                    cudaMemcpy(defaultValue.data().get() + i * value, constValue, value, cudaMemcpyDeviceToDevice);
+                }
+            }
+            cuda::launchPad(params, src, defaultValue.data().get(), dims.data().get(), outputs[0],
+                            dims.size(),
+                            blockSize);
+        };
+    }
+
+}// namespace refactor::kernel
diff --git a/src/04kernel/src/kernels/pad/cuda_kernel.hh b/src/04kernel/src/kernels/pad/cuda_kernel.hh
new file mode 100644
index 00000000..b0f915a5
--- /dev/null
+++ b/src/04kernel/src/kernels/pad/cuda_kernel.hh
@@ -0,0 +1,27 @@
+#ifndef KERNEL_PAD_CUDA_HH
+#define KERNEL_PAD_CUDA_HH
+
+#include "kernel/attributes/pad_info.h"
+#include "kernel/collectors/pad.h"
+
+namespace refactor::kernel {
+
+    struct PadCuda final : public Kernel {
+        PadInfo info;
+        PadType mode;
+        size_t valueLength;
+
+        PadCuda(PadInfo, PadType, size_t) noexcept;
+        static KernelBox build(PadInfo, PadType, std::optional<std::reference_wrapper<Tensor const>>) noexcept;
+        static size_t typeId() noexcept;
+
+        size_t kernelTypeId() const noexcept final;
+        std::string_view description() const noexcept final;
+#ifdef USE_CUDA
+        RoutineWorkspace lower(Resources &) const noexcept final;
+#endif
+    };
+
+}// namespace refactor::kernel
+
+#endif//KERNEL_PAD_CUDA_HH
diff --git a/src/04kernel/src/kernels/slice/cuda_kernel.cu b/src/04kernel/src/kernels/slice/cuda_kernel.cu
index a6a3037c..33029899 100644
--- a/src/04kernel/src/kernels/slice/cuda_kernel.cu
+++ b/src/04kernel/src/kernels/slice/cuda_kernel.cu
@@ -7,17 +7,17 @@ namespace refactor::kernel {
     using namespace runtime;
 
     auto SliceCuda::lower(Resources &) const noexcept -> RoutineWorkspace {
-        thrust::host_vector<cuda::DimInfo> dims(info.dims.size());
+        thrust::host_vector<cuda::SliceDimInfo> dims(info.dims.size());
         std::transform(info.dims.begin(), info.dims.end(),
                        dims.begin(),
                        [](auto const &d) {
-                           return cuda::DimInfo{
+                           return cuda::SliceDimInfo{
                                d.strideO,
                                d.skip,
                                d.strideI,
                            };
                        });
-        return [dims = thrust::device_vector<cuda::DimInfo>(dims),
+        return [dims = thrust::device_vector<cuda::SliceDimInfo>(dims),
                 params = cuda::ThreadsDistributer()(info.blockCount),
                 blockSize = info.blockSize](Resources &, void *workspace, void const *const *inputs, void *const *outputs) {
             auto src = reinterpret_cast<uint8_t const *>(inputs[0]);
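Why reform(16) in PadCuda::build above: after it, blockSize divides 16, so one thread moves one block with at most a single 16-byte vector transaction. macro.cuh is not part of this diff; the following is only a plausible shape for optimizedMemcpy under that assumption (assuming 16-byte aligned buffers), not the actual macro:

__device__ inline void copyBlock(uint8_t *dst, uint8_t const *src, unsigned int size) {
    if (size == 16) {
        // one vectorized 16-byte load/store per block
        *reinterpret_cast<float4 *>(dst) = *reinterpret_cast<float4 const *>(src);
    } else {
        // blockSize < 16: fall back to a byte loop
        for (unsigned int k = 0; k < size; ++k) { dst[k] = src[k]; }
    }
}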
diff --git a/src/04kernel/test/kernels/pad/test_cpu.cpp b/src/04kernel/test/kernels/pad/test_cpu.cpp
new file mode 100644
index 00000000..48b1bbb0
--- /dev/null
+++ b/src/04kernel/test/kernels/pad/test_cpu.cpp
@@ -0,0 +1,126 @@
+#include "../../../include/kernel/attributes/pad_info.h"
+#include "../../../src/kernels/pad/cpu_kernel.hh"
+#include <gtest/gtest.h>
+#include <numeric>
+
+using namespace refactor;
+using namespace kernel;
+
+TEST(kernel, PadCpu) {
+    // no constant_value
+    {
+        PadDimension dims{
+            {2, 4, 1},
+            {3, 5, 1},
+        };
+        // build routine
+        auto xTensor = Tensor::share(DataType::F32, Shape{2, 3});
+        auto yTensor = Tensor::share(DataType::F32, Shape{4, 5});
+        PadType mode = PadType::Constant;
+        auto kernel = PadCpu::build(PadInfo(dims, *xTensor), mode, std::nullopt);
+        ASSERT_TRUE(kernel);
+        auto res = runtime::Resources();
+        auto routine = kernel->lower(res).routine;
+        // set input data
+        std::vector<float>
+            data(xTensor->elementsSize(), 1),
+            result(yTensor->elementsSize());
+        // inference
+        {
+            void const *inputs[]{data.data()};
+            void *outputs[]{result.data()};
+            routine(res, nullptr, inputs, outputs);
+        }
+        // check
+        std::vector<float> output = {0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0.};
+        for (auto i : range0_(result.size())) {
+            EXPECT_FLOAT_EQ(output[i], result[i]);
+        }
+    }
+    // have constant_value
+    {
+        PadDimension dims{
+            {2, 4, 1},
+            {3, 5, 1},
+        };
+        // build routine
+        auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3});
+        auto t2Tensor = Tensor::share(DataType::I64, Shape{4});
+        auto t3Tensor = Tensor::share(DataType::F32, Shape{});
+        auto yTensor = Tensor::share(DataType::F32, Shape{4, 5});
+        PadType type = PadType::Constant;
+        auto kernel = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor)));
+        ASSERT_TRUE(kernel);
+        auto res = runtime::Resources();
+        auto routine = kernel->lower(res).routine;
+        // set input data
+        std::vector<float>
+            data(t1Tensor->elementsSize(), 1),
+            result(yTensor->elementsSize());
+        std::vector<float> constant_value(1, 1.2);
+        std::vector<int64_t> pads_value(4, 1);
+        // inference
+        {
+            void const *inputs[]{data.data(), pads_value.data(), constant_value.data()};
+            void *outputs[]{result.data()};
+            routine(res, nullptr, inputs, outputs);
+        }
+        // check
+        std::vector<float> output = {1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1., 1., 1., 1.2, 1.2, 1., 1., 1., 1.2, 1.2, 1.2, 1.2, 1.2, 1.2};
+        for (auto i : range0_(result.size())) {
+            EXPECT_FLOAT_EQ(output[i], result[i]);
+        }
+    }
+    {
+        PadDimension dims{
+            {2, 4, 1},
+            {3, 5, 1},
+            {1, 1, 0},
+            {4, 8, 2},
+        };
+        // build routine
+        auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4});
+        auto t2Tensor = Tensor::share(DataType::I64, Shape{8});
+        auto t3Tensor = Tensor::share(DataType::F32, Shape{});
+        auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8});
+        PadType type = PadType::Constant;
+        auto kernel = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor)));
+        ASSERT_TRUE(kernel);
+        auto res = runtime::Resources();
+        auto routine = kernel->lower(res).routine;
+        // set input data
+        std::vector<float>
+            data(t1Tensor->elementsSize(), 1),
+            result(yTensor->elementsSize());
+        std::vector<float> constant_value(1, 1.2);
+        std::vector<int64_t> pads_value{1, 1, 0, 2, 1, 1, 0, 2};
+        // inference
+        {
+            void const *inputs[]{data.data(), pads_value.data(), constant_value.data()};
+            void *outputs[]{result.data()};
+            routine(res, nullptr, inputs, outputs);
+        }
+        // check
+        std::vector<float> output = {1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.0000, 1.0000, 1.0000, 1.0000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.0000, 1.0000, 1.0000, 1.0000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.0000, 1.0000, 1.0000, 1.0000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.0000, 1.0000, 1.0000, 1.0000, 1.2000, 1.2000, 1.2000, 1.2000, 1.0000,
+                                     1.0000, 1.0000, 1.0000, 1.2000, 1.2000, 1.2000, 1.2000, 1.0000, 1.0000,
+                                     1.0000, 1.0000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000,
+                                     1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000, 1.2000};
+        for (auto i : range0_(result.size())) {
+            EXPECT_FLOAT_EQ(output[i], result[i]);
+        }
+    }
+}
diff --git a/src/04kernel/test/kernels/pad/test_cuda.cpp b/src/04kernel/test/kernels/pad/test_cuda.cpp
new file mode 100644
index 00000000..4c755535
--- /dev/null
+++ b/src/04kernel/test/kernels/pad/test_cuda.cpp
@@ -0,0 +1,128 @@
+#ifdef USE_CUDA
+
+#include "../../../src/kernels/pad/cpu_kernel.hh"
+#include "../../../src/kernels/pad/cuda_kernel.hh"
+#include "hardware/device_manager.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace kernel;
+using namespace hardware;
+
+TEST(kernel, PadCuda) {
+    {
+        PadDimension dims{
+            {2, 4, 1},
+            {3, 5, 1},
+            {1, 1, 0},
+            {4, 8, 2},
+        };
+        // build routine
+        auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4});
+        auto t2Tensor = Tensor::share(DataType::I64, Shape{8});
+        auto t3Tensor = Tensor::share(DataType::F32, Shape{});
+        auto yTensor = Tensor::share(DataType::F32, Shape{4, 5, 1, 8});
+        PadType type = PadType::Constant;
+        auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor)));
+        auto kernel = PadCuda::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor)));
+        ASSERT_TRUE(kernel && kCpu);
+        auto res = runtime::Resources();
+        auto routine = kernel->lower(res).routine,
+             rCpu = kCpu->lower(res).routine;
+        // malloc
+        auto &dev = *device::init(Device::Type::Nvidia, 0, "");
+        auto gpuIn = dev.malloc(t1Tensor->bytesSize()),
+             gpuIn2 = dev.malloc(t2Tensor->bytesSize()),
+             gpuIn3 = dev.malloc(t3Tensor->bytesSize()),
+             gpuOut = dev.malloc(yTensor->bytesSize());
+        // put input data
+        std::vector<float> data(t1Tensor->elementsSize()),
+            constvalue(1, 1.2f),
+            cpuOut(yTensor->elementsSize());
+        std::vector<int64_t> pads{1, 1, 0, 2, 1, 1, 0, 2};
+
+        for (auto i : range0_(data.size())) { data[i] = i; }
+        gpuIn->copyFromHost(data.data(), t1Tensor->bytesSize());
+        gpuIn2->copyFromHost(pads.data(), t2Tensor->bytesSize());
+        gpuIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize());
+
+        // inference
+        {
+            void const *inputs[]{*gpuIn, *gpuIn2, *gpuIn3};
+            void *outputs[]{*gpuOut};
+            routine(res, nullptr, inputs, outputs);
+        }
+        {
+            void const *inputs[]{data.data(), pads.data(), constvalue.data()};
+            void *outputs[]{cpuOut.data()};
+            rCpu(res, nullptr, inputs, outputs);
+        }
+        // take output data
+        std::vector<float> result(yTensor->elementsSize());
+        gpuOut->copyToHost(result.data(), yTensor->bytesSize());
+        // check
+        for (auto i : range0_(cpuOut.size())) {
+            EXPECT_FLOAT_EQ(cpuOut[i], result[i]);
+        }
+    }
+
+    {
+        PadDimension dims{
+            {2, 2, 0},
+            {3, 3, 0},
+            {1, 1, 0},
+            {4, 4, 0},
+        };
+        // build routine
+        auto t1Tensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4});
+        auto t2Tensor = Tensor::share(DataType::I64, Shape{8});
+        auto t3Tensor = Tensor::share(DataType::F32, Shape{});
+        auto yTensor = Tensor::share(DataType::F32, Shape{2, 3, 1, 4});
+        PadType type = PadType::Constant;
+        auto kCpu = PadCpu::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor)));
+        auto kernel = PadCuda::build(PadInfo(dims, *t1Tensor), type, std::make_optional(std::reference_wrapper(*t3Tensor)));
+        ASSERT_TRUE(kernel && kCpu);
+        auto res = runtime::Resources();
+        auto routine = kernel->lower(res).routine,
+             rCpu = kCpu->lower(res).routine;
+        // malloc
+        auto &dev = *device::init(Device::Type::Nvidia, 0, "");
+        auto gpuIn = dev.malloc(t1Tensor->bytesSize()),
+             gpuIn2 = dev.malloc(t2Tensor->bytesSize()),
+             gpuIn3 = dev.malloc(t3Tensor->bytesSize()),
+             gpuOut = dev.malloc(yTensor->bytesSize());
+        // put input data
+        std::vector<float> data(t1Tensor->elementsSize()),
+            constvalue(1, 1.2f),
+            cpuOut(yTensor->elementsSize());
+        std::vector<int64_t> pads{0, 0, 0, 0, 0, 0, 0, 0};
+
+        for (auto i : range0_(data.size())) { data[i] = i; }
+        gpuIn->copyFromHost(data.data(), t1Tensor->bytesSize());
+        gpuIn2->copyFromHost(pads.data(), t2Tensor->bytesSize());
+        gpuIn3->copyFromHost(constvalue.data(), t3Tensor->bytesSize());
+
+        // inference
+        {
+            void const *inputs[]{*gpuIn, *gpuIn2, *gpuIn3};
+            void *outputs[]{*gpuOut};
+            routine(res, nullptr, inputs, outputs);
+        }
+        {
+            void const *inputs[]{data.data(), pads.data(), constvalue.data()};
+            void *outputs[]{cpuOut.data()};
+            rCpu(res, nullptr, inputs, outputs);
+        }
+        // take output data
+        std::vector<float> result(yTensor->elementsSize());
+        gpuOut->copyToHost(result.data(), yTensor->bytesSize());
+        // check
+        for (auto i : range0_(cpuOut.size())) {
+            EXPECT_FLOAT_EQ(cpuOut[i], result[i]);
+        }
+    }
+}
+
+#endif
diff --git a/src/04kernel/test/kernels/slice/test_cuda.cpp b/src/04kernel/test/kernels/slice/test_cuda.cpp
index d54938d0..7ea419e3 100644
--- a/src/04kernel/test/kernels/slice/test_cuda.cpp
+++ b/src/04kernel/test/kernels/slice/test_cuda.cpp
@@ -11,49 +11,96 @@ using namespace kernel;
 using namespace hardware;
 
 TEST(kernel, SliceCuda) {
-    // build routine
-    Dimensions dims{
-        {5, -2, 3},// 7 -> {5, 3, 1} -> {108, 900, -360}
-        {2, 3, 2}, // 6 -> {2, 5} -> { 36, 60, 90}
-        {1, 1, 3}, // 5 -> {1, 2, 3} -> { 18, 6, 30}
-        {0, 1, 1}, // 1 -> {0}
-        {0, 1, 2}, // 2 -> {0, 1}
-        {0, 1, 3}, // 3 -> {0, 1, 2}
-    };
-    auto input = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3}),
-         output = Tensor::share(DataType::F32, Shape{3, 2, 3, 1, 2, 3});
-    SliceInfo info(dims, *input);
-    auto kernel = SliceCuda::build(info);
-    auto kCpu = SliceCpu::build(info);
-    ASSERT_TRUE(kernel && kCpu);
-    auto res = runtime::Resources();
-    auto routine = kernel->lower(res).routine;
-    auto rCpu = kCpu->lower(res).routine;
-    // malloc
-    auto &dev = *device::init(Device::Type::Nvidia, 0, "");
-    auto gpuIn = dev.malloc(input->bytesSize()),
-         gpuOut = dev.malloc(output->bytesSize());
-    // put input data
-    std::vector<float>
-        data(input->elementsSize()),
-        ans(output->elementsSize()),
-        result(ans.size());
-    std::iota(data.begin(), data.end(), 0);
-    gpuIn->copyFromHost(data.data(), input->bytesSize());
-    // inference
     {
-        void const *inputs[]{*gpuIn};
-        void *outputs[]{*gpuOut};
-        routine(res, nullptr, inputs, outputs);
+        // build routine
+        Dimensions dims{
+            {5, -2, 3},// 7 -> {5, 3, 1} -> {108, 900, -360}
+            {2, 3, 2}, // 6 -> {2, 5} -> { 36, 60, 90}
+            {1, 1, 3}, // 5 -> {1, 2, 3} -> { 18, 6, 30}
+            {0, 1, 1}, // 1 -> {0}
+            {0, 1, 2}, // 2 -> {0, 1}
+            {0, 1, 3}, // 3 -> {0, 1, 2}
+        };
+        auto input = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3}),
+             output = Tensor::share(DataType::F32, Shape{3, 2, 3, 1, 2, 3});
+        SliceInfo info(dims, *input);
+        auto kernel = SliceCuda::build(info);
+        auto kCpu = SliceCpu::build(info);
+        ASSERT_TRUE(kernel && kCpu);
+        auto res = runtime::Resources();
+        auto routine = kernel->lower(res).routine;
+        auto rCpu = kCpu->lower(res).routine;
+        // malloc
+        auto &dev = *device::init(Device::Type::Nvidia, 0, "");
+        auto gpuIn = dev.malloc(input->bytesSize()),
+             gpuOut = dev.malloc(output->bytesSize());
+        // put input data
+        std::vector<float>
+            data(input->elementsSize()),
+            ans(output->elementsSize()),
+            result(ans.size());
+        std::iota(data.begin(), data.end(), 0);
+        gpuIn->copyFromHost(data.data(), input->bytesSize());
+        // inference
+        {
+            void const *inputs[]{*gpuIn};
+            void *outputs[]{*gpuOut};
+            routine(res, nullptr, inputs, outputs);
+        }
+        {
+            void const *inputs[]{data.data()};
+            void *outputs[]{ans.data()};
+            rCpu(res, nullptr, inputs, outputs);
+        }
+        // check
+        gpuOut->copyToHost(result.data(), output->bytesSize());
+        EXPECT_EQ(result, ans);
     }
     {
-        void const *inputs[]{data.data()};
-        void *outputs[]{ans.data()};
-        rCpu(res, nullptr, inputs, outputs);
+        // build routine
+        Dimensions dims{
+            {0, 1, 7},
+            {0, 1, 6},
+            {0, 1, 5},
+            {0, 1, 1},
+            {0, 1, 2},
+            {0, 1, 3},
+        };
+        auto input = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3}),
+             output = Tensor::share(DataType::F32, Shape{7, 6, 5, 1, 2, 3});
+        SliceInfo info(dims, *input);
+        auto kernel = SliceCuda::build(info);
+        auto kCpu = SliceCpu::build(info);
+        ASSERT_TRUE(kernel && kCpu);
+        auto res = runtime::Resources();
+        auto routine = kernel->lower(res).routine;
+        auto rCpu = kCpu->lower(res).routine;
+        // malloc
+        auto &dev = *device::init(Device::Type::Nvidia, 0, "");
+        auto gpuIn = dev.malloc(input->bytesSize()),
+             gpuOut = dev.malloc(output->bytesSize());
+        // put input data
+        std::vector<float>
+            data(input->elementsSize()),
+            ans(output->elementsSize()),
+            result(ans.size());
+        std::iota(data.begin(), data.end(), 0);
+        gpuIn->copyFromHost(data.data(), input->bytesSize());
+        // inference
+        {
+            void const *inputs[]{*gpuIn};
+            void *outputs[]{*gpuOut};
+            routine(res, nullptr, inputs, outputs);
+        }
+        {
+            void const *inputs[]{data.data()};
+            void *outputs[]{ans.data()};
+            rCpu(res, nullptr, inputs, outputs);
+        }
+        // check
+        gpuOut->copyToHost(result.data(), output->bytesSize());
+        EXPECT_EQ(result, ans);
     }
-    // check
-    gpuOut->copyToHost(result.data(), output->bytesSize());
-    EXPECT_EQ(result, ans);
 }
 
 #endif
diff --git a/src/05computation/include/computation/operators/pad.h b/src/05computation/include/computation/operators/pad.h
new file mode 100644
index 00000000..173fcae7
--- /dev/null
+++ b/src/05computation/include/computation/operators/pad.h
@@ -0,0 +1,26 @@
+#ifndef COMPUTATION_PAD_H
+#define COMPUTATION_PAD_H
+
+#include "../operator.h"
+#include "kernel/collectors/pad.h"
+
+namespace refactor::computation {
+    using kernel::PadType;
+    using Dimensions = kernel::PadDimension;
+
+    struct Pad final : public LayoutDependentOperator {
+        Dimensions dims;
+        PadType mode;
+
+        Pad(decltype(dims), PadType) noexcept;
+
+        static size_t typeId() noexcept;
+        size_t opTypeId() const noexcept final;
+        std::string_view name() const noexcept final;
+        kernel::CollectorBox candidateKernels(Target) const noexcept final;
+        std::string serialize() const noexcept final;
+    };
+
+}// namespace refactor::computation
+
+#endif// COMPUTATION_PAD_H
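A hedged usage sketch for the graph-level operator above; the Target value is illustrative (the collector in this diff handles Cpu and Nvidia):

using namespace refactor::computation;

Dimensions dims{{2, 4, 1}, {3, 5, 1}};// {dimI, dimO, pads} per axis
Pad op(std::move(dims), PadType::Constant);
auto collector = op.candidateKernels(Target::Cpu);// yields a kernel::PadCollector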
diff --git a/src/05computation/src/operators/pad.cc b/src/05computation/src/operators/pad.cc
new file mode 100644
index 00000000..243f8536
--- /dev/null
+++ b/src/05computation/src/operators/pad.cc
@@ -0,0 +1,30 @@
+#include "computation/operators/pad.h"
+#include "kernel/attributes/pad_info.h"
+
+namespace refactor::computation {
+    using Op = Pad;
+
+    Op::Pad(decltype(dims) dims_,
+            PadType mode_) noexcept : LayoutDependentOperator(), dims(std::move(dims_)), mode(mode_) {}
+
+    auto Op::typeId() noexcept -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+    auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
+    auto Op::name() const noexcept -> std::string_view { return "Pad"; }
+    auto Op::candidateKernels(Target target) const noexcept -> kernel::CollectorBox {
+        using Collector_ = kernel::PadCollector;
+        return std::make_unique<Collector_>(target, dims, mode);
+    }
+    auto Op::serialize() const noexcept -> std::string {
+        std::stringstream ss;
+        ss << name() << "([";
+        for (auto const &d : dims) {
+            ss << "input = " << d.dimI << ", output = " << d.dimO << ", pads = " << d.pads << "; ";
+        }
+        ss << "mode = " << mode.toString() << "])";
+        return ss.str();
+    }
+
+}// namespace refactor::computation
diff --git a/src/07onnx/src/operators.cpp b/src/07onnx/src/operators.cpp
index a565a8d3..ddfc0066 100644
--- a/src/07onnx/src/operators.cpp
+++ b/src/07onnx/src/operators.cpp
@@ -20,6 +20,7 @@
 #include "operators/hard_sigmoid.hh"
 #include "operators/mat_mul.hh"
 #include "operators/mat_mul_integer.hh"
+#include "operators/pad.hh"
 #include "operators/pool.hh"
 #include "operators/range.hh"
 #include "operators/reduce.hh"
@@ -128,6 +129,7 @@ namespace refactor::onnx {
         REGISTER(Unsqueeze          , Unsqueeze         );
         REGISTER(Where              , Where             );
         REGISTER(HardSigmoid        , HardSigmoid       );
+        REGISTER(Pad                , Pad               );
 #undef REGISTER
         // clang-format on
     }
diff --git a/src/07onnx/src/operators/pad.cc b/src/07onnx/src/operators/pad.cc
new file mode 100644
index 00000000..c61f0812
--- /dev/null
+++ b/src/07onnx/src/operators/pad.cc
@@ -0,0 +1,152 @@
+#include "pad.hh"
+#include "common.h"
+#include "computation/operators/pad.h"
+#include <memory>
+
+namespace refactor::onnx {
+    using Op = Pad;
+    using Pm = PadMode;
+
+    Op::Pad(Pm mode_) : Operator(), mode(mode_) {}
+
+    auto Op::build(ModelContext const &, std::string_view opType, Attributes attributes) -> OpBox {
+        auto mode = attributes.getOrInsert("mode", {"constant"}).string();
+        Pm pm;
+        if (mode == "constant") {
+            pm = Pm::Constant;
+        } else if (mode == "reflect") {
+            pm = Pm::Reflect;
+        } else if (mode == "edge") {
+            pm = Pm::Edge;
+        } else if (mode == "wrap") {
+            pm = Pm::Wrap;
+        } else {
+            UNREACHABLEX(void, "Unsupported Pad mode: {}", mode);
+        }
+        return OpBox(std::make_unique<Op>(pm));
+    }
+    auto Op::typeId() -> size_t {
+        static uint8_t ID = 1;
+        return reinterpret_cast<size_t>(&ID);
+    }
+
+    auto Op::opTypeId() const -> size_t { return typeId(); }
+    auto Op::opTypeName() const -> std::string_view { return "onnx::Pad"; }
+
+    auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
+        if (inputs.size() < 2 || inputs.size() > 4) {
+            return Err(InferError(ERROR_MSG("Input size error")));
+        }
+        auto const &input = inputs[0];
+        auto const &pad = inputs[1];
+        if (pad.dataType != DataType::I64 || pad.rank() != 1) {
+            return Err(InferError(ERROR_MSG("Pad input pads is invalid")));
+        }
+        EXPECT_VAL(pad.shape[0], pad_len)
+        if (!pad.data) {
+            return Err(InferError(ERROR_MSG("Pad input pads must be constant")));
+        }
+        int64_t const *pads = pad.data->get<int64_t>();
+        Ints pads_;
+        // TODO: support negative values in the pads input of onnx Pad
+        for (auto i : range0_(pad.shape[0].value())) {
+            if (auto pad_value = pads[i]; pad_value < 0) {
+                return Err(InferError(ERROR_MSG("Pad input pads does not support negative numbers")));
+            }
+            pads_.push_back(pads[i]);
+        }
+        auto rank = input.rank();
+        if (inputs.size() >= 3) {
+            auto const &constant_value = inputs[2];
+            if (constant_value.dataType != input.dataType) {
+                return Err(InferError(ERROR_MSG("Pad input constant_value type mismatch")));
+            }
+        }
+        if (inputs.size() == 4) {
+            Ints pads__(2 * rank, 0);
+            auto const &axes = inputs[3];
+            if ((axes.dataType != DataType::I32 && axes.dataType != DataType::I64) || axes.rank() != 1) {
+                return Err(InferError(ERROR_MSG("Pad input axes is invalid")));
+            }
+            if (!axes.data) {
+                return Err(InferError(ERROR_MSG("Pad input axes must be constant")));
+            }
+            EXPECT_VAL(axes.shape[0], axes_len)
+            if (pad_len != 2 * axes_len) {
+                return Err(InferError(ERROR_MSG("Pad input pads length is not 2x axes length")));
+            }
+            void const *axes_data = axes.data->get();
+            for (auto i : range0_(axes_len)) {
+                auto axis = axes.dataType == DataType::I32 ? reinterpret_cast<int32_t const *>(axes_data)[i] : reinterpret_cast<int64_t const *>(axes_data)[i];
+                if (axis < 0) {
+                    axis += rank;
+                }
+                if (axis < 0 || axis >= rank) {
+                    return Err(InferError(ERROR_MSG("Pad input axes out of range")));
+                }
+                pads__[axis] = pads_[i];
+                pads__[axis + rank] = pads_[i + axes_len];
+            }
+            pads_ = pads__;
+        } else {
+            if (pad_len != 2 * rank) {
+                return Err(InferError(ERROR_MSG("Pad input pads length is not 2x input rank")));
+            }
+        }
+        Shape output_shape(rank, DimExpr(1));
+        for (auto i : range0_(rank)) {
+            output_shape[i] = DimExpr(input.shape[i].value() + pads_[i] + pads_[i + rank]);
+        }
+        auto ans = Tensor::share(input.dataType, output_shape, extractDependency(inputs));
+        return Ok(Tensors{std::move(ans)});
+    }
+
+    auto Op::lower(TensorRefs inputs) const -> computation::OpBox {
+        using Ty_ = computation::PadType;
+        using Op_ = computation::Pad;
+        using Dimension = computation::Dimensions;
+
+        auto rank = inputs[0].rank();
+        int64_t const *pads_ = inputs[1].data->get<int64_t>();
+        std::vector<int64_t> pads_info(2 * rank, 0);
+        Dimension dims(rank);
+        if (inputs.size() != 4) {
+            for (auto i : range0_(inputs[1].shape[0].value())) { pads_info[i] = pads_[i]; }
+        } else {
+            auto const &axes_ = inputs[3];
+            void const *axes_data = axes_.data->get();
+            auto axes_len = axes_.shape[0].value();
+
+            for (auto i : range0_(axes_len)) {
+                auto axis = axes_.dataType == DataType::I32 ? reinterpret_cast<int32_t const *>(axes_data)[i] : reinterpret_cast<int64_t const *>(axes_data)[i];
+                if (axis < 0) { axis += rank; }
+                pads_info[axis] = pads_[i];
+                pads_info[axis + rank] = pads_[i + axes_len];
+            }
+        }
+        for (auto i : range0_(rank)) {
+            auto dimI = inputs[0].shape[i].value();
+            dims[i] = {
+                dimI, dimI + pads_info[i] + pads_info[i + rank], pads_info[i]};
+        }
+        Ty_ mode_;
+        switch (mode) {
+            case Pm::Constant:
+                mode_ = Ty_::Constant;
+                break;
+            case Pm::Reflect:
+                mode_ = Ty_::Reflect;
+                break;
+            case Pm::Edge:
+                mode_ = Ty_::Edge;
+                break;
+            case Pm::Wrap:
+                mode_ = Ty_::Wrap;
+                break;
+            default:
+                UNREACHABLE();
+        }
+        return std::make_unique<Op_>(std::move(dims), mode_);
+    }
+
+}// namespace refactor::onnx
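The axes handling in infer/lower above scatters per-axis pads into the full 2*rank ONNX layout {x1_begin, ..., xn_begin, x1_end, ..., xn_end}. A standalone mirror of that expansion (the function name expandPads is illustrative, not part of the change):

#include <cstdint>
#include <vector>

// pads holds begin values for the listed axes, then end values, in order.
std::vector<int64_t> expandPads(std::vector<int64_t> const &pads,
                                std::vector<int64_t> const &axes,
                                int64_t rank) {
    std::vector<int64_t> full(2 * rank, 0);
    auto half = static_cast<int64_t>(axes.size());
    for (int64_t i = 0; i < half; ++i) {
        auto axis = axes[i] < 0 ? axes[i] + rank : axes[i];// normalize negatives
        full[axis] = pads[i];
        full[axis + rank] = pads[i + half];
    }
    return full;
}
// expandPads({1, 2}, {1}, 3) == {0, 1, 0, 0, 2, 0}: only axis 1 grows,
// so shape {2, 3, 4} becomes {2, 6, 4}.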
diff --git a/src/07onnx/src/operators/pad.hh b/src/07onnx/src/operators/pad.hh
new file mode 100644
index 00000000..d3ab28b6
--- /dev/null
+++ b/src/07onnx/src/operators/pad.hh
@@ -0,0 +1,31 @@
+#ifndef ONNX_PAD_HH
+#define ONNX_PAD_HH
+
+#include "frontend/operator.h"
+
+namespace refactor::onnx {
+    using namespace frontend;
+
+    enum class PadMode {
+        Constant,
+        Reflect,
+        Edge,
+        Wrap,
+    };
+
+    struct Pad final : public Operator {
+        PadMode mode;
+
+        Pad(PadMode);
+
+        static OpBox build(ModelContext const &, std::string_view, Attributes);
+        static size_t typeId();
+        size_t opTypeId() const final;
+        std::string_view opTypeName() const final;
+        InferResult infer(TensorRefs, InferOptions const &) const final;
+        computation::OpBox lower(TensorRefs) const final;
+    };
+
+}// namespace refactor::onnx
+
+#endif// ONNX_PAD_HH
diff --git a/src/07onnx/test/test_pad.cpp b/src/07onnx/test/test_pad.cpp
new file mode 100644
index 00000000..6492e2bc
--- /dev/null
+++ b/src/07onnx/test/test_pad.cpp
@@ -0,0 +1,47 @@
+#include "../src/operators/pad.hh"
+#include "onnx/operators.h"
+#include <gtest/gtest.h>
+
+using namespace refactor;
+using namespace onnx;
+
+TEST(infer, Pad) {
+    onnx::register_();
+
+    {
+        auto edges = Edges{
+            {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3), DimExpr(4)}, {}), ""},
+            {Tensor::share(DataType::I64, Shape{DimExpr(6)}, {}), ""},
+        };
+        auto ptr = reinterpret_cast<int64_t *>(edges[1].tensor->malloc());
+        std::fill(ptr, ptr + edges[1].tensor->elementsSize(), 1);
+        count_t inputs[]{0, 1};
+        auto infered = Pad(PadMode::Constant).infer(TensorRefs(edges, inputs), {false});
+        ASSERT_TRUE(infered.isOk());
+        auto outputs = std::move(infered.unwrap());
+        ASSERT_EQ(outputs.size(), 1);
+        auto y = std::move(outputs[0]);
+        ASSERT_EQ(y->dataType, DataType::F32);
+        ASSERT_EQ(y->shape, (Shape{DimExpr(4), DimExpr(5), DimExpr(6)}));
+    }
+    {
+        auto edges = Edges{
+            {Tensor::share(DataType::F32, Shape{DimExpr(2), DimExpr(3), DimExpr(4)}, {}), ""},
+            {Tensor::share(DataType::I64, Shape{DimExpr(2)}, {}), ""},
+            {Tensor::share(DataType::F32, Shape{DimExpr(1)}, {}), ""},
+            {Tensor::share(DataType::I32, Shape{DimExpr(1)}, {}), ""},
+        };
+        auto ptr_pad = reinterpret_cast<int64_t *>(edges[1].tensor->malloc());
+        std::fill(ptr_pad, ptr_pad + edges[1].tensor->elementsSize(), 1);
+        auto ptr_axes = reinterpret_cast<int32_t *>(edges[3].tensor->malloc());
+        std::fill(ptr_axes, ptr_axes + edges[3].tensor->elementsSize(), 1);
+        count_t inputs[]{0, 1, 2, 3};
+        auto infered = Pad(PadMode::Constant).infer(TensorRefs(edges, inputs), {false});
+        ASSERT_TRUE(infered.isOk());
+        auto outputs = std::move(infered.unwrap());
+        ASSERT_EQ(outputs.size(), 1);
+        auto y = std::move(outputs[0]);
+        ASSERT_EQ(y->dataType, DataType::F32);
+        ASSERT_EQ(y->shape, (Shape{DimExpr(2), DimExpr(5), DimExpr(4)}));
+    }
+}