From 97bebfb40cac5b5c8df1b19939dcaed77d1dbc30 Mon Sep 17 00:00:00 2001 From: xinhaitong Date: Tue, 27 Dec 2022 13:59:13 +0800 Subject: [PATCH] fix index_put api --- include/diopi/functions.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/include/diopi/functions.h b/include/diopi/functions.h index b0bf2fe..71c6eb1 100644 --- a/include/diopi/functions.h +++ b/include/diopi/functions.h @@ -777,8 +777,8 @@ DIOPI_API diopiError_t diopiScatterScalar(diopiContextHandle_t ctx, diopiTensorH /** * \brief Puts values from the tensor values into the tensor input using the indices specified in indices. */ -DIOPI_API diopiError_t diopiIndexPutInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t values, diopiConstTensorHandle_t* indices, bool accumulate); -DIOPI_API diopiError_t diopiIndexPut(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t values, diopiConstTensorHandle_t* indices, bool accumulate); +DIOPI_API diopiError_t diopiIndexPutInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, diopiConstTensorHandle_t values, diopiConstTensorHandle_t* indices, int64_t indices_counts, bool accumulate); +DIOPI_API diopiError_t diopiIndexPut(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t values, diopiConstTensorHandle_t* indices, int64_t indices_counts, bool accumulate); /** * \brief Distribution and random numbers. @@ -817,10 +817,20 @@ DIOPI_API diopiError_t diopiUpsampleLinear(diopiContextHandle_t ctx, diopiTensor DIOPI_API diopiError_t diopiUpsampleLinearBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output, diopiSize_t out_size, diopiSize_t in_size, bool align_corners, const char* mode); +/** + * \brief Computes the inverse error function of input tensor. + */ DIOPI_API diopiError_t diopiErfinv(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input); DIOPI_API diopiError_t diopiErfinvInp(diopiContextHandle_t ctx, diopiTensorHandle_t input); +/** + * \brief Extracts sliding local blocks from a batched input tensor. + */ DIOPI_API diopiError_t diopiIm2Col(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiSize_t kernel_size, diopiSize_t dilation, diopiSize_t padding, diopiSize_t stride); + +/** + * \brief Combines an array of sliding local blocks into a large containing tensor. + */ DIOPI_API diopiError_t diopiCol2Im(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiSize_t output_size, diopiSize_t kernel_size, diopiSize_t dilation, diopiSize_t padding, diopiSize_t stride); #if defined(__cplusplus)