Skip to content

Commit

Permalink
Support roi_align
Browse files Browse the repository at this point in the history
  • Loading branch information
CokeDong committed Oct 20, 2023
1 parent 8523eee commit 86ae1ec
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions mmcv/ops/csrc/pytorch/roi_align.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@

#include "csrc_dipu/base/basedef.h"
#include "csrc_dipu/diopirt/diopirt_impl.h"
#include <torch/csrc/utils/pybind.h>
#include "csrc_dipu/utils/helpfunc.hpp"
#include "csrc_dipu/runtime/device/deviceapis.h"

using dipu::diopi_helper::toDiopiScalar;
using dipu::diopi_helper::toDiopiTensorHandle;
using dipu::VENDOR_TYPE;
#endif

void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
Expand Down Expand Up @@ -56,10 +60,18 @@ void roi_align_forward_diopi(Tensor input, Tensor rois, Tensor output,
auto argmax_x_p = toDiopiTensorHandle(argmax_x);
bool is_mock_cuda = input.device().type() == dipu::DIPU_DEVICE_TYPE;
if (is_mock_cuda && reinterpret_cast<void *>(diopiRoiAlignMmcv) != nullptr) {
auto ret = diopiRoiAlignMmcv(
ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
if (ret == diopiSuccess) return;
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
pybind11::gil_scoped_release no_gil;
auto ret = diopiRoiAlignMmcv(
ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
if (ret == diopiSuccess) return;
} else {
auto ret = diopiRoiAlignMmcv(
ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
if (ret == diopiSuccess) return;
}
}
LOG(WARNING) << "Fallback to cpu: mmcv ext op roi_align_forward";
auto input_cpu = input.cpu();
Expand Down Expand Up @@ -96,11 +108,20 @@ void roi_align_backward_diopi(Tensor grad_output, Tensor rois, Tensor argmax_y,
bool is_mock_cuda = grad_output.device().type() == dipu::DIPU_DEVICE_TYPE;
if (is_mock_cuda &&
reinterpret_cast<void *>(diopiRoiAlignBackwardMmcv) != nullptr) {
auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
argmax_y_, argmax_x_, aligned_height,
aligned_width, sampling_ratio,
pool_mode, spatial_scale, aligned);
if (ret == diopiSuccess) return;
if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
pybind11::gil_scoped_release no_gil;
auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
argmax_y_, argmax_x_, aligned_height,
aligned_width, sampling_ratio,
pool_mode, spatial_scale, aligned);
if (ret == diopiSuccess) return;
} else {
auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
argmax_y_, argmax_x_, aligned_height,
aligned_width, sampling_ratio,
pool_mode, spatial_scale, aligned);
if (ret == diopiSuccess) return;
}
}
LOG(WARNING) << "Fallback to cpu: mmcv ext op roi_align_backward";
auto grad_output_cpu = grad_output.cpu();
Expand Down

0 comments on commit 86ae1ec

Please sign in to comment.