
Commit f82375a

fix

CokeDong committed Oct 20, 2023
1 parent 86ae1ec
Showing 2 changed files with 26 additions and 23 deletions.
mmcv/ops/csrc/pytorch/roi_align.cpp (23 additions, 23 deletions)
@@ -5,16 +5,16 @@
 #include <diopi/diopirt.h>
 #include <diopi/functions.h>
 #include <diopi/functions_mmcv.h>
+#include <torch/csrc/utils/pybind.h>

 #include "csrc_dipu/base/basedef.h"
 #include "csrc_dipu/diopirt/diopirt_impl.h"
-#include <torch/csrc/utils/pybind.h>
-#include "csrc_dipu/utils/helpfunc.hpp"
 #include "csrc_dipu/runtime/device/deviceapis.h"
+#include "csrc_dipu/utils/helpfunc.hpp"

+using dipu::VENDOR_TYPE;
 using dipu::diopi_helper::toDiopiScalar;
 using dipu::diopi_helper::toDiopiTensorHandle;
-using dipu::VENDOR_TYPE;
 #endif

 void roi_align_forward_impl(Tensor input, Tensor rois, Tensor output,
@@ -61,16 +61,16 @@ void roi_align_forward_diopi(Tensor input, Tensor rois, Tensor output,
   bool is_mock_cuda = input.device().type() == dipu::DIPU_DEVICE_TYPE;
   if (is_mock_cuda && reinterpret_cast<void *>(diopiRoiAlignMmcv) != nullptr) {
     if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
-    pybind11::gil_scoped_release no_gil;
-    auto ret = diopiRoiAlignMmcv(
-        ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
-        aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
-    if (ret == diopiSuccess) return;
+      pybind11::gil_scoped_release no_gil;
+      auto ret = diopiRoiAlignMmcv(
+          ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
+          aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
+      if (ret == diopiSuccess) return;
     } else {
-    auto ret = diopiRoiAlignMmcv(
-        ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
-        aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
-    if (ret == diopiSuccess) return;
+      auto ret = diopiRoiAlignMmcv(
+          ch, out_p, argmax_y_p, argmax_x_p, input_p, rois_p, aligned_height,
+          aligned_width, sampling_ratio, pool_mode, spatial_scale, aligned);
+      if (ret == diopiSuccess) return;
     }
   }
   LOG(WARNING) << "Fallback to cpu: mmcv ext op roi_align_forward";
@@ -109,18 +109,18 @@ void roi_align_backward_diopi(Tensor grad_output, Tensor rois, Tensor argmax_y,
   if (is_mock_cuda &&
       reinterpret_cast<void *>(diopiRoiAlignBackwardMmcv) != nullptr) {
     if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
-    pybind11::gil_scoped_release no_gil;
-    auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
-                                         argmax_y_, argmax_x_, aligned_height,
-                                         aligned_width, sampling_ratio,
-                                         pool_mode, spatial_scale, aligned);
-    if (ret == diopiSuccess) return;
+      pybind11::gil_scoped_release no_gil;
+      auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
+                                           argmax_y_, argmax_x_, aligned_height,
+                                           aligned_width, sampling_ratio,
+                                           pool_mode, spatial_scale, aligned);
+      if (ret == diopiSuccess) return;
     } else {
-    auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
-                                         argmax_y_, argmax_x_, aligned_height,
-                                         aligned_width, sampling_ratio,
-                                         pool_mode, spatial_scale, aligned);
-    if (ret == diopiSuccess) return;
+      auto ret = diopiRoiAlignBackwardMmcv(ch, grad_input_, grad_output_, rois_,
+                                           argmax_y_, argmax_x_, aligned_height,
+                                           aligned_width, sampling_ratio,
+                                           pool_mode, spatial_scale, aligned);
+      if (ret == diopiSuccess) return;
     }
   }
   LOG(WARNING) << "Fallback to cpu: mmcv ext op roi_align_backward";
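Both hunks in roi_align.cpp follow the same dispatch shape: if the tensor sits on the DIPU "mock CUDA" device and the weakly linked DIOPI kernel is present, call it, releasing the Python GIL first when the vendor is NPU, and fall through to the CPU implementation on failure. Below is a minimal, self-contained sketch of that shape, not the repository's code: diopiFakeOpMmcv, vendor_is_npu, and fake_op_diopi are invented stand-ins, and the comment on the GIL release is an assumption about why the NPU branch differs.

// Illustrative sketch only; every name here is a stand-in, not a symbol
// from mmcv or DIPU.
#include <cstring>
#include <iostream>

#include <pybind11/pybind11.h>

enum diopiError_t { diopiSuccess = 0 };

// Stand-in for a weakly linked DIOPI entry point such as diopiRoiAlignMmcv;
// in the real file the symbol can be null when the vendor library does not
// provide the kernel, which is why the caller checks the pointer first.
diopiError_t diopiFakeOpMmcv() { return diopiSuccess; }

// Stand-in for strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0.
bool vendor_is_npu() { return true; }

void fake_op_diopi(bool is_mock_cuda) {
  if (is_mock_cuda && reinterpret_cast<void *>(&diopiFakeOpMmcv) != nullptr) {
    if (vendor_is_npu()) {
      // The NPU branch drops the GIL before launching, presumably so a slow
      // or re-entrant kernel launch cannot stall other Python threads
      // (assumption; valid only inside a live interpreter holding the GIL).
      pybind11::gil_scoped_release no_gil;
      if (diopiFakeOpMmcv() == diopiSuccess) return;
    } else {
      if (diopiFakeOpMmcv() == diopiSuccess) return;
    }
  }
  // Mirrors the LOG(WARNING) CPU fallback in the real file.
  std::cout << "Fallback to cpu\n";
}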
setup.py (3 additions, 0 deletions)
@@ -244,12 +244,15 @@ def get_extensions():
         dipu_path = os.getenv('DIPU_PATH')
         vendor_include_dirs = os.getenv('VENDOR_INCLUDE_DIRS')
         nccl_include_dirs = os.getenv('NCCL_INCLUDE_DIRS')
+        pytorch_dir = os.getenv('PYTORCH_DIR')
         include_dirs.append(dipu_root)
         include_dirs.append(diopi_path + '/include')
         include_dirs.append(dipu_path + '/dist/include')
         include_dirs.append(vendor_include_dirs)
         if nccl_include_dirs:
             include_dirs.append(nccl_include_dirs)
+        if pytorch_dir:
+            include_dirs.append(pytorch_dir + 'torch/include')
         library_dirs += [dipu_root]
         libraries += ['torch_dipu']
     elif is_rocm_pytorch or torch.cuda.is_available() or os.getenv(
