Skip to content

Commit

Permalink
roi_align_rotated_v2
Browse files Browse the repository at this point in the history
  • Loading branch information
ason-rob committed Nov 26, 2024
1 parent 2b7def6 commit 818ce46
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 39 deletions.
31 changes: 21 additions & 10 deletions mmcv/ops/csrc/pytorch/npu/roi_align_rotated_v2_npu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,36 @@
using namespace NPU_NAME_SPACE;
using namespace std;

void roi_align_rotated_v2_forward_npu(const Tensor input, Tensor rois_map,
Tensor output,
void roi_align_rotated_v2_forward_npu(const Tensor x, Tensor rois_map,
Tensor y,
int32_t pooled_h,
int32_t pooled_w,
double spatial_scale,
int32_t sampling_ratio,
int32_t pooled_height,
int32_t pooled_width,
bool aligned,
bool clockwise) {
at::Tensor feature_map = input.permute({0, 2, 3, 1}).contiguous();
at::Tensor feature_map = x.permute({0, 2, 3, 1}).contiguous();
at::Tensor rois = rois_map.permute({1, 0}).contiguous();
EXEC_NPU_CMD(aclnnRoiAlignRotatedV2, feature_map, rois, spatial_scale, sampling_ratio, pooled_height, pooled_width, aligned, clockwise, output);
at_npu::native::OpCommand cmd;
cmd.Name("RoiAlignRotated")
.Input(feature_map)
.Input(rois)
.Output(y)
.Attr("pooled_h", static_cast<int64_t>(pooled_h))
.Attr("pooled_w", static_cast<int64_t>(pooled_w))
.Attr("spatial_scale", static_cast<float>(spatial_scale))
.Attr("sampling_ratio", static_cast<int64_t>(sampling_ratio))
.Attr("aligned", aligned)
.Attr("clockwise", clockwise)
.Run();
}

void roi_align_rotated_v2_forward_impl(const Tensor input, Tensor rois,
Tensor output,
void roi_align_rotated_v2_forward_impl(const Tensor x, Tensor rois,
Tensor y,
int32_t pooled_h,
int32_t pooled_w,
double spatial_scale,
int32_t sampling_ratio,
int32_t pooled_height,
int32_t pooled_width,
bool aligned,
bool clockwise);

Expand Down
60 changes: 31 additions & 29 deletions mmcv/ops/roi_align_rotated_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,55 +14,55 @@
class RoIAlignRotatedV2Function(Function):

@staticmethod
def symbolic(g, input, rois, spatial_scale, sampling_ratio, pooled_height,
pooled_width, aligned, clockwise):
def symbolic(g, x, rois, spatial_scale, sampling_ratio, pooled_h,
pooled_w, aligned, clockwise):
return g.op(
'mmcv::MMCVRoIAlignRotatedV2',
input,
x,
rois,
pooled_h=pooled_h,
pooled_w=pooled_w,
spatial_scale_f=spatial_scale,
sampling_ratio_i=sampling_ratio,
pooled_height=pooled_height,
pooled_width=pooled_width,
aligned_i=aligned,
clockwise_i=clockwise)

@staticmethod
def forward(ctx: Any,
input: torch.Tensor,
x: torch.Tensor,
rois: torch.Tensor,
pooled_h: int,
pooled_w: int,
spatial_scale: float,
sampling_ratio: int,
pooled_height: int,
pooled_width: int,
aligned: bool = True,
clockwise: bool = False) -> torch.Tensor:
ctx.pooled_height = pooled_height
ctx.pooled_width = pooled_width
ctx.pooled_h = pooled_h
ctx.pooled_w = pooled_w
ctx.spatial_scale = spatial_scale
ctx.sampling_ratio = sampling_ratio
ctx.aligned = aligned
ctx.clockwise = clockwise
ctx.save_for_backward(input, rois)
ctx.feature_size = input.size()
batch_size, num_channels, data_height, data_width = input.size()
ctx.save_for_backward(x, rois)
ctx.feature_size = x.size()
batch_size, num_channels, data_height, data_width = x.size()
num_rois = rois.size(0)

output = input.new_zeros(num_rois, ctx.pooled_height, ctx.pooled_width,
y = x.new_zeros(num_rois, ctx.pooled_h, ctx.pooled_w,
num_channels)

ext_module.roi_align_rotated_v2_forward(
input,
x,
rois,
output,
y,
pooled_h=ctx.pooled_h,
pooled_w=ctx.pooled_w,
spatial_scale=ctx.spatial_scale,
sampling_ratio=ctx.sampling_ratio,
pooled_height=ctx.pooled_height,
pooled_width=ctx.pooled_width,
aligned=ctx.aligned,
clockwise=ctx.clockwise)
output = output.transpose(2, 3).transpose(1, 2).contiguous()
return output
y = y.transpose(2, 3).transpose(1, 2).contiguous()
return y

@staticmethod
def backward(ctx: Any, grad_output: torch.Tensor):
Expand All @@ -74,7 +74,7 @@ def backward(ctx: Any, grad_output: torch.Tensor):
input.size(0), input.size(2), input.size(3), input.size(1))
ext_module.roi_align_rotated_v2_backward(
input, rois_trans, grad_output_trans, grad_input,
ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
ctx.pooled_h, ctx.pooled_w, ctx.spatial_scale,
ctx.sampling_ratio, ctx.aligned, ctx.clockwise)
grad_input = grad_input.permute(0, 3, 1, 2).contiguous()

Expand Down Expand Up @@ -134,31 +134,33 @@ class RoIAlignRotatedV2(nn.Module):
},
cls_name='RoIAlignRotatedV2')
def __init__(self,
pooled_h: int,
pooled_w: int,
spatial_scale: float,
sampling_ratio: int,
pooled_height: int,
pooled_width: int,
aligned: bool = True,
clockwise: bool = False):
super().__init__()

self.pooled_height = int(pooled_height)
self.pooled_width = int(pooled_width)
self.pooled_h = int(pooled_h)
self.pooled_w = int(pooled_w)
self.spatial_scale = float(spatial_scale)
self.sampling_ratio = int(sampling_ratio)
self.aligned = aligned
self.clockwise = clockwise

def forward(self, input: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
return RoIAlignRotatedV2Function.apply(input, rois, self.spatial_scale,
return RoIAlignRotatedV2Function.apply(input, rois,
self.pooled_h,
self.pooled_w,
self.spatial_scale,
self.sampling_ratio,
self.pooled_height,
self.pooled_width, self.aligned,
self.aligned,
self.clockwise)

def __repr__(self):
s = self.__class__.__name__
s += f'(pooled_height={self.pooled_height}, '
s += f'(pooled_h={self.pooled_h}, '
s += f'spatial_scale={self.spatial_scale}, '
s += f'sampling_ratio={self.sampling_ratio}, '
s += f'aligned={self.aligned}, '
Expand Down

0 comments on commit 818ce46

Please sign in to comment.