Skip to content
This repository has been archived by the owner on Jun 8, 2023. It is now read-only.

Commit

Permalink
add normal_ / hardswish/ hardswish_ (#21)
Browse files Browse the repository at this point in the history
* add normal_ / hardswish/ hardswish_

* fix formal parameters

* add meshgrid

* update function paras
  • Loading branch information
z379035389 authored Apr 4, 2023
1 parent 75f66b7 commit 2b0edbe
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ DIOPI_API diopiError_t diopiHardtanhInp(diopiContextHandle_t ctx, diopiTensorHan
DIOPI_API diopiError_t diopiHardtanhBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output,
diopiConstTensorHandle_t input, const diopiScalar_t* min_val, const diopiScalar_t* max_val);

DIOPI_API diopiError_t diopiHardswish(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input);
DIOPI_API diopiError_t diopiHardswishInp(diopiContextHandle_t ctx, diopiTensorHandle_t input);

DIOPI_API diopiError_t diopiThreshold(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input,
const diopiScalar_t* threshold, const diopiScalar_t* value);
DIOPI_API diopiError_t diopiThresholdInp(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* threshold, const diopiScalar_t* value);
Expand Down Expand Up @@ -862,7 +865,9 @@ DIOPI_API diopiError_t diopiNormal(diopiContextHandle_t ctx, diopiTensorHandle_t
DIOPI_API diopiError_t diopiNormalTensorScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t mean, double std);
DIOPI_API diopiError_t diopiNormalScalarTensor(diopiContextHandle_t ctx, diopiTensorHandle_t out, double mean, diopiConstTensorHandle_t std);
DIOPI_API diopiError_t diopiNormalTensor(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t mean, diopiConstTensorHandle_t std);
DIOPI_API diopiError_t diopiNormalInp(diopiContextHandle_t ctx, diopiTensorHandle_t inout, double mean, double std);

DIOPI_API diopiError_t diopiMeshGrid(diopiContextHandle_t ctx, diopiTensorHandle_t* outs, diopiConstTensorHandle_t* inputs, int64_t inputsNum);
/**
* \brief Applies Layer Normalization over a mini-batch of inputs.
*/
Expand Down

0 comments on commit 2b0edbe

Please sign in to comment.