
Commit bac4da8

adapt ext ops for npu device for mmcv
CokeDong committed Oct 24, 2023
1 parent f82375a commit bac4da8
Showing 5 changed files with 108 additions and 32 deletions.
30 changes: 24 additions & 6 deletions mmcv/ops/csrc/pytorch/focal_loss.cpp
@@ -5,9 +5,13 @@
 #include <diopi/diopirt.h>
 #include <diopi/functions.h>
 #include <diopi/functions_mmcv.h>
+#include <torch/csrc/utils/pybind.h>

 #include "csrc_dipu/diopirt/diopirt_impl.h"
 #include "csrc_dipu/runtime/device/deviceapis.h"
+#include "csrc_dipu/utils/helpfunc.hpp"

+using dipu::VENDOR_TYPE;
 using dipu::diopi_helper::toDiopiScalar;
 using dipu::diopi_helper::toDiopiTensorHandle;
 #endif
@@ -57,9 +61,16 @@ void sigmoid_focal_loss_forward_diopi(Tensor input, Tensor target,
   auto weight_p = toDiopiTensorHandle(weight);
   auto output_p = toDiopiTensorHandle(output);
   if (reinterpret_cast<void *>(diopiSigmoidFocalLossMmcv) != nullptr) {
-    auto ret = diopiSigmoidFocalLossMmcv(ch, output_p, input_p, target_p,
-                                         weight_p, gamma, alpha);
-    if (ret == diopiSuccess) return;
+    if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
+      pybind11::gil_scoped_release no_gil;
+      auto ret = diopiSigmoidFocalLossMmcv(ch, output_p, input_p, target_p,
+                                           weight_p, gamma, alpha);
+      if (ret == diopiSuccess) return;
+    } else {
+      auto ret = diopiSigmoidFocalLossMmcv(ch, output_p, input_p, target_p,
+                                           weight_p, gamma, alpha);
+      if (ret == diopiSuccess) return;
+    }
   }
   LOG(WARNING)
       << "Fallback to cpu: mmcv ext op sigmoid_focal_loss_forward_impl";
@@ -90,9 +101,16 @@ void sigmoid_focal_loss_backward_diopi(Tensor input, Tensor target,
   auto weight_p = toDiopiTensorHandle(weight);
   auto grad_input_p = toDiopiTensorHandle(grad_input);
   if (reinterpret_cast<void *>(diopiSigmoidFocalLossBackwardMmcv) != nullptr) {
-    auto ret = diopiSigmoidFocalLossBackwardMmcv(
-        ch, grad_input_p, input_p, target_p, weight_p, gamma, alpha);
-    if (ret == diopiSuccess) return;
+    if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
+      pybind11::gil_scoped_release no_gil;
+      auto ret = diopiSigmoidFocalLossBackwardMmcv(
+          ch, grad_input_p, input_p, target_p, weight_p, gamma, alpha);
+      if (ret == diopiSuccess) return;
+    } else {
+      auto ret = diopiSigmoidFocalLossBackwardMmcv(
+          ch, grad_input_p, input_p, target_p, weight_p, gamma, alpha);
+      if (ret == diopiSuccess) return;
+    }
   }
   LOG(WARNING)
       << "Fallback to cpu: mmcv ext op sigmoid_focal_loss_forward_impl";
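The vendor-dispatch pattern added above repeats in every C++ file below: when the DIOPI symbol is resolved, the vendor string is checked, and on "NPU" the call is made with the Python GIL released via pybind11::gil_scoped_release, while every other vendor takes the plain call path. The commit does not say why the GIL is dropped only on NPU; a plausible reading is that the NPU DIOPI implementation can block long enough that holding the GIL would stall other Python threads. Since the two branches differ only in the gil_scoped_release, they could be expressed once. A minimal sketch of such a helper, assuming only the identifiers the diff itself uses (dipu::VendorTypeToStr, dipu::VENDOR_TYPE); the name callDiopiOp is invented and this is not part of the commit:

#include <cstring>  // strcmp

#include <torch/csrc/utils/pybind.h>  // pybind11::gil_scoped_release

#include "csrc_dipu/utils/helpfunc.hpp"  // dipu::VendorTypeToStr (per the diff's include)

// Hypothetical helper, not part of this commit: run a DIOPI call,
// releasing the Python GIL first when the active vendor is NPU.
template <typename Call>
auto callDiopiOp(Call&& call) -> decltype(call()) {
  if (strcmp(dipu::VendorTypeToStr(dipu::VENDOR_TYPE), "NPU") == 0) {
    // Assumption: NPU DIOPI ops may block, so the GIL is released around them.
    pybind11::gil_scoped_release no_gil;
    return call();
  }
  return call();
}

With such a helper, the focal-loss forward dispatch above would reduce to:

  auto ret = callDiopiOp([&] {
    return diopiSigmoidFocalLossMmcv(ch, output_p, input_p, target_p,
                                     weight_p, gamma, alpha);
  });
  if (ret == diopiSuccess) return;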
47 changes: 36 additions & 11 deletions mmcv/ops/csrc/pytorch/modulated_deform_conv.cpp
@@ -5,9 +5,13 @@
 #include <diopi/diopirt.h>
 #include <diopi/functions.h>
 #include <diopi/functions_mmcv.h>
+#include <torch/csrc/utils/pybind.h>

 #include "csrc_dipu/diopirt/diopirt_impl.h"
 #include "csrc_dipu/runtime/device/deviceapis.h"
+#include "csrc_dipu/utils/helpfunc.hpp"

+using dipu::VENDOR_TYPE;
 using dipu::diopi_helper::toDiopiScalar;
 using dipu::diopi_helper::toDiopiTensorHandle;
 #endif
@@ -273,11 +277,20 @@ void modulated_deform_conv_forward_diopi(
   auto output_p = toDiopiTensorHandle(output);
   auto columns_p = toDiopiTensorHandle(columns);
   if (reinterpret_cast<void*>(diopiModulatedDeformConvMmcv) != nullptr) {
-    auto ret = diopiModulatedDeformConvMmcv(
-        ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p,
-        mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
-        dilation_h, dilation_w, group, deformable_group, with_bias);
-    if (ret == diopiSuccess) return;
+    if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
+      pybind11::gil_scoped_release no_gil;
+      auto ret = diopiModulatedDeformConvMmcv(
+          ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p,
+          mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
+          dilation_h, dilation_w, group, deformable_group, with_bias);
+      if (ret == diopiSuccess) return;
+    } else {
+      auto ret = diopiModulatedDeformConvMmcv(
+          ch, output_p, columns_p, ones_p, input_p, weight_p, bias_p, offset_p,
+          mask_p, kernel_h, kernel_w, stride_h, stride_w, pad_h, pad_w,
+          dilation_h, dilation_w, group, deformable_group, with_bias);
+      if (ret == diopiSuccess) return;
+    }
   }
   LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward";
   auto input_cpu = input.cpu();
@@ -331,12 +344,24 @@ void modulated_deform_conv_backward_diopi(

   if (reinterpret_cast<void*>(diopiModulatedDeformConvBackwardMmcv) !=
       nullptr) {
-    auto ret = diopiModulatedDeformConvBackwardMmcv(
-        ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p,
-        grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p,
-        columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w, pad_h,
-        pad_w, dilation_h, dilation_w, group, deformable_group, with_bias);
-    if (ret == diopiSuccess) return;
+    if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
+      pybind11::gil_scoped_release no_gil;
+      auto ret = diopiModulatedDeformConvBackwardMmcv(
+          ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p,
+          grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p,
+          columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w,
+          pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
+          with_bias);
+      if (ret == diopiSuccess) return;
+    } else {
+      auto ret = diopiModulatedDeformConvBackwardMmcv(
+          ch, grad_input_p, grad_weight_p, grad_bias_p, grad_offset_p,
+          grad_mask_p, input_p, weight_p, bias_p, ones_p, offset_p, mask_p,
+          columns_p, grad_output_p, kernel_h, kernel_w, stride_h, stride_w,
+          pad_h, pad_w, dilation_h, dilation_w, group, deformable_group,
+          with_bias);
+      if (ret == diopiSuccess) return;
+    }
   }
   LOG(WARNING) << "Fallback to cpu: mmcv ext op modulated_deform_conv_forward";
   auto input_cpu = input.cpu();
24 changes: 19 additions & 5 deletions mmcv/ops/csrc/pytorch/nms.cpp
@@ -5,10 +5,14 @@
 #include <diopi/diopirt.h>
 #include <diopi/functions.h>
 #include <diopi/functions_mmcv.h>
+#include <torch/csrc/utils/pybind.h>

 #include "csrc_dipu/base/basedef.h"
 #include "csrc_dipu/diopirt/diopirt_impl.h"
 #include "csrc_dipu/runtime/device/deviceapis.h"
+#include "csrc_dipu/utils/helpfunc.hpp"

+using dipu::VENDOR_TYPE;
 using dipu::diopi_helper::toDiopiScalar;
 using dipu::diopi_helper::toDiopiTensorHandle;
 #endif
@@ -45,11 +49,21 @@ Tensor nms_diopi(Tensor boxes, Tensor scores, float iou_threshold, int offset) {
   auto scores_p = toDiopiTensorHandle(scores);
   bool is_mock_cuda = boxes.device().type() == dipu::DIPU_DEVICE_TYPE;
   if (is_mock_cuda && reinterpret_cast<void*>(diopiNmsMmcv) != nullptr) {
-    auto ret =
-        diopiNmsMmcv(ch, outhandle, boxes_p, scores_p, iou_threshold, offset);
-    if (ret == diopiSuccess) {
-      auto tensorhandle = reinterpret_cast<Tensor*>(*outhandle);
-      return *tensorhandle;
-    }
+    if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
+      pybind11::gil_scoped_release no_gil;
+      auto ret =
+          diopiNmsMmcv(ch, outhandle, boxes_p, scores_p, iou_threshold, offset);
+      if (ret == diopiSuccess) {
+        auto tensorhandle = reinterpret_cast<Tensor*>(*outhandle);
+        return *tensorhandle;
+      }
+    } else {
+      auto ret =
+          diopiNmsMmcv(ch, outhandle, boxes_p, scores_p, iou_threshold, offset);
+      if (ret == diopiSuccess) {
+        auto tensorhandle = reinterpret_cast<Tensor*>(*outhandle);
+        return *tensorhandle;
+      }
+    }
   }
   LOG(WARNING) << "Fallback to cpu: mmcv ext op nms";
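Unlike the void-returning ops, nms_diopi receives its result through a DIOPI out handle: on success, the handle written by diopiNmsMmcv is reinterpreted back into the at::Tensor that the DIPU DIOPI runtime allocated behind it. A sketch of that unwrap step in isolation, under the assumption (implied by the reinterpret_cast in the diff) that the runtime represents a tensor handle as a pointer to an at::Tensor; unwrapDiopiOutTensor is an invented name:

#include <ATen/ATen.h>

#include <diopi/diopirt.h>  // diopiTensorHandle_t

// Sketch only: convert the out handle filled by a DIOPI op back to a Tensor.
at::Tensor unwrapDiopiOutTensor(diopiTensorHandle_t* outhandle) {
  auto* boxed = reinterpret_cast<at::Tensor*>(*outhandle);
  return *boxed;  // copies the Tensor header; the underlying storage stays shared
}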
36 changes: 28 additions & 8 deletions mmcv/ops/csrc/pytorch/voxelization.cpp
@@ -5,9 +5,13 @@
 #include <diopi/diopirt.h>
 #include <diopi/functions.h>
 #include <diopi/functions_mmcv.h>
+#include <torch/csrc/utils/pybind.h>

 #include "csrc_dipu/diopirt/diopirt_impl.h"
 #include "csrc_dipu/runtime/device/deviceapis.h"
+#include "csrc_dipu/utils/helpfunc.hpp"

+using dipu::VENDOR_TYPE;
 using dipu::diopi_helper::toDiopiScalar;
 using dipu::diopi_helper::toDiopiTensorHandle;
 #endif
@@ -84,11 +88,20 @@ void hard_voxelize_forward_diopi(const at::Tensor &points,
   auto num_points_per_voxel_p = toDiopiTensorHandle(num_points_per_voxel);
   auto voxel_num_p = toDiopiTensorHandle(voxel_num);
   if (reinterpret_cast<void *>(diopiHardVoxelizeMmcv) != nullptr) {
-    auto ret = diopiHardVoxelizeMmcv(
-        ch, voxels_p, coors_p, num_points_per_voxel_p, voxel_num_p, points_p,
-        voxel_size_p, coors_range_p, max_points, max_voxels, NDim,
-        deterministic);
-    if (ret == diopiSuccess) return;
+    if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
+      pybind11::gil_scoped_release no_gil;
+      auto ret = diopiHardVoxelizeMmcv(
+          ch, voxels_p, coors_p, num_points_per_voxel_p, voxel_num_p, points_p,
+          voxel_size_p, coors_range_p, max_points, max_voxels, NDim,
+          deterministic);
+      if (ret == diopiSuccess) return;
+    } else {
+      auto ret = diopiHardVoxelizeMmcv(
+          ch, voxels_p, coors_p, num_points_per_voxel_p, voxel_num_p, points_p,
+          voxel_size_p, coors_range_p, max_points, max_voxels, NDim,
+          deterministic);
+      if (ret == diopiSuccess) return;
+    }
   }
   LOG(WARNING) << "Fallback to cpu: mmcv ext op hard_voxelize_forward";
   auto points_cpu = points.cpu();
@@ -146,9 +159,16 @@ void dynamic_voxelize_forward_diopi(const at::Tensor &points,
   auto coors_range_p = toDiopiTensorHandle(coors_range);
   auto coors_p = toDiopiTensorHandle(coors);
   if (reinterpret_cast<void *>(diopiDynamicVoxelizeMmcv) != nullptr) {
-    auto ret = diopiDynamicVoxelizeMmcv(ch, coors_p, points_p, voxel_size_p,
-                                        coors_range_p, NDim);
-    if (ret == diopiSuccess) return;
+    if (strcmp(dipu::VendorTypeToStr(VENDOR_TYPE), "NPU") == 0) {
+      pybind11::gil_scoped_release no_gil;
+      auto ret = diopiDynamicVoxelizeMmcv(ch, coors_p, points_p, voxel_size_p,
+                                          coors_range_p, NDim);
+      if (ret == diopiSuccess) return;
+    } else {
+      auto ret = diopiDynamicVoxelizeMmcv(ch, coors_p, points_p, voxel_size_p,
+                                          coors_range_p, NDim);
+      if (ret == diopiSuccess) return;
+    }
   }
   LOG(WARNING) << "Fallback to cpu: mmcv ext op dynamic_voxelize_forward";
   auto points_cpu = points.cpu();
3 changes: 1 addition & 2 deletions setup.py
@@ -249,10 +249,9 @@ def get_extensions():
         include_dirs.append(diopi_path + '/include')
         include_dirs.append(dipu_path + '/dist/include')
         include_dirs.append(vendor_include_dirs)
+        include_dirs.append(pytorch_dir + 'torch/include')
         if nccl_include_dirs:
             include_dirs.append(nccl_include_dirs)
-        if pytorch_dir:
-            include_dirs.append(pytorch_dir + 'torch/include')
         library_dirs += [dipu_root]
         libraries += ['torch_dipu']
     elif is_rocm_pytorch or torch.cuda.is_available() or os.getenv(
