Skip to content

Commit

Permalink
Update Pose head (#107)
Browse files Browse the repository at this point in the history
  • Loading branch information
HonzaCuhel authored Aug 30, 2024
1 parent 22a3f11 commit e602107
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
22 changes: 12 additions & 10 deletions tools/modules/heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ def make_anchors(feats, strides, grid_cell_offset=0.5):
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
sy, sx = torch.meshgrid(sy, sx, indexing="ij")
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
return torch.cat(anchor_points), torch.cat(stride_tensor)
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2).transpose(0, 1))
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device).transpose(0, 1))
return anchor_points, stride_tensor
# return torch.cat(anchor_points), torch.cat(stride_tensor)


class DetectV5(nn.Module):
Expand Down Expand Up @@ -432,26 +433,27 @@ def forward(self, x):
"""Perform forward pass through YOLO model and return predictions."""
bs = x[0].shape[0] # batch size
if self.shape != bs:
self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))
self.anchors, self.strides = make_anchors(x, self.stride, 0.5)
self.shape = bs

# Detection part
outputs = super().forward(x)

# Pose part
kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
pred_kpt = self.kpts_decode(bs, kpt)
outputs.append(pred_kpt)
for i in range(self.nl):
kpt = self.cv4[i](x[i]).view(bs, self.nk, -1)
outputs.append(self.kpts_decode(bs, kpt, i))

return outputs

def kpts_decode(self, bs, kpts):
def kpts_decode(self, bs, kpts, i):
"""Decodes keypoints."""
ndim = self.kpt_shape[1]
y = kpts.view(bs, *self.kpt_shape, -1)
a = (y[:, :, :2] * 2.0 + (self.anchors - 0.5)) * self.strides
a = (y[:, :, :2] * 2.0 + (self.anchors[i] - 0.5)) * self.strides[i]
if ndim == 3:
a = torch.cat((a, y[:, :, 2:3].sigmoid()), 2)
# a = torch.cat((a, y[:, :, 2:3].sigmoid()*10), 2)
a = torch.cat((a, y[:, :, 2:3]), 2)
return a.view(bs, self.nk, -1)


Expand Down
2 changes: 1 addition & 1 deletion tools/yolo/yolov8_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_output_names(mode: int) -> List[str]:
elif mode == OBB_MODE:
return ["output1_yolov8", "output2_yolov8", "output3_yolov8", "angle_output"]
elif mode == POSE_MODE:
return ["output1_yolov8", "output2_yolov8", "output3_yolov8", "kpt_output"]
return ["output1_yolov8", "output2_yolov8", "output3_yolov8", "kpt_output1", "kpt_output2", "kpt_output3"]
return ["output"]


Expand Down

0 comments on commit e602107

Please sign in to comment.