-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathlosses.py
125 lines (98 loc) · 4.74 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
from common import constants
def perspective_projection(points,
                           rotation,
                           translation,
                           focal_length,
                           camera_center):
    """
    This function computes the perspective projection of a set of points.
    Input:
        points (bs, N, 3): 3D points
        rotation (bs, 3, 3): Camera rotation
        translation (bs, 3): Camera translation
            NOTE(review): this argument is accepted but never applied inside
            this function; presumably callers pass points that are already
            expressed in the camera frame -- confirm against the call sites.
        focal_length (bs,) or scalar: Focal length
        camera_center (bs, 2): Camera center
    Output:
        (bs, N, 2) tensor with the projected 2D coordinates of the points.
    """
    batch_size = points.shape[0]

    # Transform points
    points = torch.einsum('bij,bkj->bki', rotation, points)

    # Apply perspective distortion (divide by the depth coordinate)
    projected_points = points / points[:, :, -1].unsqueeze(-1)

    # Build the intrinsics in the dtype/device of the tensor they multiply,
    # so that non-default precisions (e.g. float64) do not hit a dtype
    # mismatch inside einsum.
    K = torch.zeros([batch_size, 3, 3],
                    device=projected_points.device,
                    dtype=projected_points.dtype)
    K[:, 0, 0] = focal_length
    K[:, 1, 1] = focal_length
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center

    # Apply camera intrinsics
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)

    # Drop the homogeneous coordinate
    return projected_points[:, :, :-1]
def gmof(x, sigma):
    """
    Geman-McClure error function.

    Computes sigma^2 * x^2 / (sigma^2 + x^2) element-wise for the
    residual ``x`` and the scale ``sigma``.
    """
    residual_sq = x ** 2
    scale_sq = sigma ** 2
    numerator = scale_sq * residual_sq
    denominator = scale_sq + residual_sq
    return numerator / denominator
def angle_prior(pose):
    """
    Angle prior that penalizes unnatural bending of the knees and elbows.

    Selects four columns of ``pose``, flips the sign of three of them and
    returns the squared exponential of the result, one value per selected
    column and batch element.
    """
    # Offsets are shifted by 3 because pose does not include the global
    # rotation of the model
    selected_columns = [55 - 3, 58 - 3, 12 - 3, 15 - 3]
    signs = torch.tensor([1., -1., -1, -1.], device=pose.device)
    signed_angles = pose[:, selected_columns] * signs
    return torch.exp(signed_angles) ** 2
def body_fitting_loss(body_pose, betas, model_joints, camera_t, camera_center,
                      joints_2d, joints_conf, pose_prior,
                      focal_length=5000, sigma=100, pose_prior_weight=4.78,
                      shape_prior_weight=5, angle_prior_weight=15.2,
                      output='sum'):
    """
    Loss function for body fitting.

    Combines a confidence-weighted robust reprojection error with a pose
    prior, an angle prior for knees and elbows and a shape regularizer.

    Input:
        body_pose: Body pose parameters (first dimension is the batch)
        betas: Shape parameters
        model_joints: 3D model joints that are projected to the image
        camera_t: Camera translation passed to the projection
        camera_center: Camera center passed to the projection
        joints_2d: Target 2D joint locations
        joints_conf: Per-joint confidences used to weight the reprojection error
        pose_prior: Callable evaluated as ``pose_prior(body_pose, betas)``
        focal_length: Focal length passed to the projection
        sigma: Scale of the Geman-McClure robustifier
        pose_prior_weight, shape_prior_weight, angle_prior_weight:
            Weights of the individual terms (squared before use)
        output: 'sum' to return the total loss summed over the batch,
            'reprojection' to return the weighted per-joint reprojection loss
    Raises:
        ValueError: if ``output`` is neither 'sum' nor 'reprojection'.
    """
    # Reject unknown modes up front instead of silently returning None
    if output not in ('sum', 'reprojection'):
        raise ValueError(
            "Unsupported output mode {!r}; expected 'sum' or "
            "'reprojection'".format(output))

    batch_size = body_pose.shape[0]
    rotation = torch.eye(3, device=body_pose.device).unsqueeze(0).expand(batch_size, -1, -1)
    projected_joints = perspective_projection(model_joints, rotation, camera_t,
                                              focal_length, camera_center)

    # Weighted robust reprojection error
    reprojection_error = gmof(projected_joints - joints_2d, sigma)
    reprojection_loss = (joints_conf ** 2) * reprojection_error.sum(dim=-1)

    # The priors are not part of this output, so skip evaluating them
    if output == 'reprojection':
        return reprojection_loss

    # Pose prior loss
    pose_prior_loss = (pose_prior_weight ** 2) * pose_prior(body_pose, betas)

    # Angle prior for knees and elbows
    angle_prior_loss = (angle_prior_weight ** 2) * angle_prior(body_pose).sum(dim=-1)

    # Regularizer to prevent betas from taking large values
    shape_prior_loss = (shape_prior_weight ** 2) * (betas ** 2).sum(dim=-1)

    total_loss = reprojection_loss.sum(dim=-1) + pose_prior_loss + angle_prior_loss + shape_prior_loss
    return total_loss.sum()
def camera_fitting_loss(model_joints, camera_t, camera_t_est, camera_center, joints_2d, joints_conf,
                        focal_length=5000, depth_loss_weight=100):
    """
    Loss function for camera optimization.

    Sums, over the batch, a squared reprojection error on the hip and
    shoulder joints and a weighted squared deviation of the camera depth
    from its estimate.
    """
    # Project model joints using an identity rotation for every batch element
    n_batch = model_joints.shape[0]
    identity = torch.eye(3, device=model_joints.device)
    identity = identity.unsqueeze(0).expand(n_batch, -1, -1)
    joints_proj = perspective_projection(model_joints, identity, camera_t,
                                         focal_length, camera_center)

    # Torso joints: OpenPose detections and their GT counterparts
    op_names = ['OP RHip', 'OP LHip', 'OP RShoulder', 'OP LShoulder']
    op_idx = [constants.JOINT_IDS[name] for name in op_names]
    gt_names = ['Right Hip', 'Left Hip', 'Right Shoulder', 'Left Shoulder']
    gt_idx = [constants.JOINT_IDS[name] for name in gt_names]

    sq_err_op = (joints_2d[:, op_idx] - joints_proj[:, op_idx]) ** 2
    sq_err_gt = (joints_2d[:, gt_idx] - joints_proj[:, gt_idx]) ** 2

    # Check if for each example in the batch all 4 OpenPose detections are valid, otherwise use the GT detections
    # OpenPose joints are more reliable for this task, so we prefer to use them if possible
    lowest_op_conf = joints_conf[:, op_idx].min(dim=-1)[0]
    use_op = (lowest_op_conf[:, None, None] > 0).float()
    blended_error = use_op * sq_err_op + (1 - use_op) * sq_err_gt
    reprojection_term = blended_error.sum(dim=(1, 2))

    # Loss that penalizes deviation from depth estimate
    depth_gap = camera_t[:, 2] - camera_t_est[:, 2]
    depth_term = (depth_loss_weight ** 2) * depth_gap ** 2

    return (reprojection_term + depth_term).sum()