Skip to content
This repository has been archived by the owner on Jun 8, 2023. It is now read-only.

Commit

Permalink
add backward op def of linear and cross_entropy, add fill def
Browse files Browse the repository at this point in the history
  • Loading branch information
xinhaitong authored and HU-qingqing committed Nov 18, 2022
1 parent 167ecb8 commit 1ea9ce1
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions include/diopi/functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ DIOPI_API diopiError_t diopiSigmoidFocalLossBackward(diopiContextHandle_t ctx, d
DIOPI_API diopiError_t diopiCrossEntropyLoss(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input,
diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight, diopiReduction_t reduction,
int64_t ignore_index, double label_smoothing);
DIOPI_API diopiError_t diopiCrossEntropyLossBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output,
diopiConstTensorHandle_t input, diopiConstTensorHandle_t target, diopiConstTensorHandle_t weight,
diopiReduction_t reduction, int64_t ignore_index, double label_smoothing);

/**
* \brief Measures thee nll loss between the target and input probabilities.
Expand Down Expand Up @@ -307,6 +310,11 @@ DIOPI_API diopiError_t diopiClampMinInp(diopiContextHandle_t ctx, diopiTensorHan
DIOPI_API diopiError_t diopiClampMinScalar(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, const diopiScalar_t* min);
DIOPI_API diopiError_t diopiClampMin(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input, diopiConstTensorHandle_t min);

/**
* \brief Fills elements of self tensor with value.
*/
DIOPI_API diopiError_t diopiFill(diopiContextHandle_t ctx, diopiTensorHandle_t input, const diopiScalar_t* value);

/**
* \brief Computes the element-wise logical AND/OR of the given input tensors.
*/
Expand Down Expand Up @@ -445,6 +453,8 @@ DIOPI_API diopiError_t diopiNonzero(diopiContextHandle_t ctx, diopiTensorHandle_
*/
DIOPI_API diopiError_t diopiLinear(diopiContextHandle_t ctx, diopiTensorHandle_t out, diopiConstTensorHandle_t input,
diopiConstTensorHandle_t weight, diopiConstTensorHandle_t bias);
DIOPI_API diopiError_t diopiLinearBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiTensorHandle_t grad_weight, diopiTensorHandle_t grad_bias,
diopiConstTensorHandle_t grad_output, diopiConstTensorHandle_t input, diopiConstTensorHandle_t weight);

/**
* \brief
Expand Down Expand Up @@ -806,6 +816,7 @@ DIOPI_API diopiError_t diopiUpsampleLinear(diopiContextHandle_t ctx, diopiTensor
bool align_corners, const char* mode);
DIOPI_API diopiError_t diopiUpsampleLinearBackward(diopiContextHandle_t ctx, diopiTensorHandle_t grad_input, diopiConstTensorHandle_t grad_output,
diopiSize_t out_size, diopiSize_t in_size, bool align_corners, const char* mode);

#if defined(__cplusplus)
}
#endif // __cplusplus
Expand Down

0 comments on commit 1ea9ce1

Please sign in to comment.